
Commit aae1465

fix: Surface relevant exception when initializing langchain model
1 parent 9b59488 commit aae1465

3 files changed: +49 -15 lines changed

nemoguardrails/llm/models/langchain_initializer.py

Lines changed: 18 additions & 11 deletions
@@ -225,7 +225,7 @@ def _init_chat_completion_model(model_name: str, provider_name: str, kwargs: Dic
         raise


-def _init_text_completion_model(model_name: str, provider_name: str, kwargs: Dict[str, Any]) -> BaseLLM:
+def _init_text_completion_model(model_name: str, provider_name: str, kwargs: Dict[str, Any]) -> BaseLLM | None:
     """Initialize a text completion model.

     Args:
@@ -234,22 +234,24 @@ def _init_text_completion_model(model_name: str, provider_name: str, kwargs: Dic
         kwargs: Additional arguments to pass to the model initialization

     Returns:
-        An initialized text completion model
-
-    Raises:
-        RuntimeError: If the provider is not found
+        An initialized text completion model, or None if the provider is not found
     """
-    provider_cls = _get_text_completion_provider(provider_name)
+    try:
+        provider_cls = _get_text_completion_provider(provider_name)
+    except RuntimeError:
+        return None
+
     if provider_cls is None:
-        raise ValueError()
+        return None
+
     kwargs = _update_model_kwargs(provider_cls, model_name, kwargs)
     # remove stream_usage parameter as it's not supported by text completion APIs
     # (e.g., OpenAI's AsyncCompletions.create() doesn't accept this parameter)
     kwargs.pop("stream_usage", None)
     return provider_cls(**kwargs)


-def _init_community_chat_models(model_name: str, provider_name: str, kwargs: Dict[str, Any]) -> BaseChatModel:
+def _init_community_chat_models(model_name: str, provider_name: str, kwargs: Dict[str, Any]) -> BaseChatModel | None:
     """Initialize community chat models.

     Args:
@@ -264,14 +266,19 @@ def _init_community_chat_models(model_name: str, provider_name: str, kwargs: Dic
         ImportError: If langchain_community is not installed
         ModelInitializationError: If model initialization fails
     """
-    provider_cls = _get_chat_completion_provider(provider_name)
+    try:
+        provider_cls = _get_chat_completion_provider(provider_name)
+    except RuntimeError:
+        return None
+
     if provider_cls is None:
-        raise ValueError()
+        return None
+
     kwargs = _update_model_kwargs(provider_cls, model_name, kwargs)
     return provider_cls(**kwargs)


-def _init_gpt35_turbo_instruct(model_name: str, provider_name: str, kwargs: Dict[str, Any]) -> BaseLLM:
+def _init_gpt35_turbo_instruct(model_name: str, provider_name: str, kwargs: Dict[str, Any]) -> BaseLLM | None:
     """Initialize GPT-3.5 Turbo Instruct model.

     Currently init_chat_model from langchain infers this as a chat model.
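
Why the None returns above are safe: the per-provider helpers now signal "this provider type does not apply" by returning None instead of raising a bare ValueError(), so the orchestrating init_langchain_model (not shown in this diff) can keep the first real failure and surface it to the user. The snippet below is a minimal, hypothetical sketch of that orchestration, assuming the initializers are tried in order and that ImportError takes priority over other exceptions, as the new tests further down expect; it is an illustration, not the actual nemoguardrails implementation (init_model_sketch and the local ModelInitializationError are stand-ins).

from typing import Any, Callable, Dict, List, Optional


class ModelInitializationError(Exception):
    """Stand-in for the exception exported by langchain_initializer."""


def init_model_sketch(
    model_name: str,
    provider_name: str,
    kwargs: Dict[str, Any],
    initializers: List[Callable[[str, str, Dict[str, Any]], Optional[Any]]],
) -> Any:
    """Try each initializer in order; a None return means 'does not apply here'."""
    first_error: Optional[Exception] = None
    import_error: Optional[ImportError] = None

    for initializer in initializers:
        try:
            model = initializer(model_name, provider_name, kwargs)
        except ImportError as exc:  # e.g. langchain_community is not installed
            import_error = import_error or exc
            continue
        except Exception as exc:  # e.g. an invalid API key or bad configuration
            first_error = first_error or exc
            continue
        if model is not None:
            return model

    # Surface the most relevant failure instead of a generic "no provider" error;
    # an ImportError usually tells the user exactly which package to install.
    relevant: Optional[Exception] = import_error or first_error
    if relevant is not None:
        raise ModelInitializationError(str(relevant)) from relevant
    raise ModelInitializationError(
        f"Could not initialize model {model_name!r} for provider {provider_name!r}"
    )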

tests/llm_providers/test_langchain_initialization_methods.py

Lines changed: 2 additions & 4 deletions
@@ -116,8 +116,7 @@ def test_init_community_chat_models_no_provider(self):
             "nemoguardrails.llm.models.langchain_initializer._get_chat_completion_provider"
         ) as mock_get_provider:
             mock_get_provider.return_value = None
-            with pytest.raises(ValueError):
-                _init_community_chat_models("community-model", "provider", {})
+            assert _init_community_chat_models("community-model", "provider", {}) is None


 class TestTextCompletionInitializer:
@@ -156,8 +155,7 @@ def test_init_text_completion_model_no_provider(self):
             "nemoguardrails.llm.models.langchain_initializer._get_text_completion_provider"
         ) as mock_get_provider:
             mock_get_provider.return_value = None
-            with pytest.raises(ValueError):
-                _init_text_completion_model("text-model", "provider", {})
+            assert _init_text_completion_model("text-model", "provider", {}) is None


 class TestUpdateModelKwargs:

tests/llm_providers/test_langchain_initializer.py

Lines changed: 29 additions & 0 deletions
@@ -194,3 +194,32 @@ def test_text_completion_supports_chat_mode(mock_initializers):
     mock_initializers["chat"].assert_called_once()
     mock_initializers["community"].assert_called_once()
     mock_initializers["text"].assert_called_once()
+
+
+# Tests for error masking prevention (issue where later None returns mask earlier exceptions)
+
+
+def test_exception_not_masked_by_none_return(mock_initializers):
+    """Test that an exception from an initializer is preserved when later ones return None.
+
+    For example: if community chat throws an error (e.g., invalid API key), but text completion
+    returns None because that provider type doesn't exist, the community error should be raised.
+    """
+    mock_initializers["special"].return_value = None
+    mock_initializers["chat"].return_value = None
+    mock_initializers["community"].side_effect = ValueError("Invalid API key for provider")
+    mock_initializers["text"].return_value = None  # Provider not found, returns None
+
+    with pytest.raises(ModelInitializationError, match="Invalid API key for provider"):
+        init_langchain_model("community-model", "provider", "chat", {})
+
+
+def test_import_error_prioritized_over_other_exceptions(mock_initializers):
+    """Test that ImportError is surfaced to help users know when packages are missing."""
+    mock_initializers["special"].return_value = None
+    mock_initializers["chat"].side_effect = ValueError("Some config error")
+    mock_initializers["community"].side_effect = ImportError("Missing langchain_community package")
+    mock_initializers["text"].return_value = None
+
+    with pytest.raises(ModelInitializationError, match="Missing langchain_community package"):
+        init_langchain_model("model", "provider", "chat", {})
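
These tests rely on a mock_initializers fixture that already exists in this test module and is not part of the diff. For readers following along, here is a rough, hypothetical sketch of what such a fixture could look like, assuming it patches the four private initializer functions by module path and that the "special" key refers to the GPT-3.5 Turbo Instruct helper; the patch targets and key mapping are assumptions, not taken from the repository.

from unittest.mock import patch

import pytest

_BASE = "nemoguardrails.llm.models.langchain_initializer"


@pytest.fixture
def mock_initializers():
    """Patch the private initializers so each test can script returns/exceptions."""
    # NOTE: the exact patch targets below are assumptions made for illustration.
    with patch(f"{_BASE}._init_gpt35_turbo_instruct") as special, patch(
        f"{_BASE}._init_chat_completion_model"
    ) as chat, patch(f"{_BASE}._init_community_chat_models") as community, patch(
        f"{_BASE}._init_text_completion_model"
    ) as text:
        yield {"special": special, "chat": chat, "community": community, "text": text}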
