
Commit df829e4: Fix tests
Parent: ace439c

4 files changed (+32, -20 lines)


nemoguardrails/actions/llm/utils.py

Lines changed: 1 addition & 0 deletions

@@ -47,6 +47,7 @@
     "endpoint",
     "endpoint_url",
     "openai_api_base",
+    "server_url",
 ]

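For context, this list enumerates attribute names that are probed when enriching a failed LLM call with its endpoint, and the new "server_url" entry is what the TRT-LLM test below relies on. A minimal sketch of the probing pattern, assuming a plain getattr scan: the helper name _get_endpoint is hypothetical, the first two list entries are inferred from the tests rather than visible in the hunk, and the nested client.base_url fallback is suggested by the last test below.

# Illustrative sketch only; not the actual nemoguardrails implementation.
ENDPOINT_ATTRIBUTES = [
    "base_url",        # assumed from the OpenAI-style test below
    "azure_endpoint",  # assumed from the Azure test below
    "endpoint",
    "endpoint_url",
    "openai_api_base",
    "server_url",      # newly added: picks up TRT-LLM style servers
]


def _get_endpoint(llm) -> str | None:  # hypothetical helper
    for name in ENDPOINT_ATTRIBUTES:
        value = getattr(llm, name, None)
        if value:
            return str(value)
    # Fallback: some clients nest the URL, e.g. llm.client.base_url.
    client = getattr(llm, "client", None)
    return str(client.base_url) if getattr(client, "base_url", None) else None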

nemoguardrails/rails/llm/config.py

Lines changed: 1 addition & 1 deletion

@@ -1526,7 +1526,7 @@ def check_model_exists_for_output_rails(cls, values):
                 else "none"
             )
             raise InvalidRailsConfigurationError(
-                "Output flow {flow_id} references model type '{flow_model}' that is not defined in the configuration. Detected model types: {available_types}."
+                f"Output flow '{flow_id}' references model type '{flow_model}' that is not defined in the configuration. Detected model types: {available_types}."
             )
         return values
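The only change here is the missing f prefix: without it, Python treats {flow_id} and the other placeholders as literal braces instead of interpolating the variables. A quick illustration with made-up values:

flow_id, flow_model = "self check output", "content_safety"

# Without the f prefix, the braces are literal text:
print("Output flow {flow_id} references model type '{flow_model}'.")
# -> Output flow {flow_id} references model type '{flow_model}'.

# With it, the variables are interpolated:
print(f"Output flow '{flow_id}' references model type '{flow_model}'.")
# -> Output flow 'self check output' references model type 'content_safety'.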

tests/test_actions_llm_utils.py

Lines changed: 28 additions & 17 deletions

@@ -13,9 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import cast
 from unittest.mock import AsyncMock
 
 import pytest
+from langchain_core.language_models import BaseLanguageModel
 
 from nemoguardrails.actions.llm.utils import (
     _extract_and_remove_think_tags,
@@ -55,6 +57,24 @@ class MockPatchedNVIDIA(MockNVIDIAOriginal):
     __module__ = "nemoguardrails.llm.providers._langchain_nvidia_ai_endpoints_patch"
 
 
+class MockTRTLLM:
+    __module__ = "nemoguardrails.llm.providers.trtllm.llm"
+
+
+class MockAzureLLM:
+    __module__ = "langchain_openai.chat_models"
+
+
+class MockLLMWithClient:
+    __module__ = "langchain_openai.chat_models"
+
+    class _MockClient:
+        base_url = "https://custom.endpoint.com/v1"
+
+    def __init__(self):
+        self.client = self._MockClient()
+
+
 def test_infer_provider_openai():
     llm = MockOpenAILLM()
     provider = _infer_provider_from_module(llm)
@@ -315,14 +335,13 @@ def test_extract_and_remove_think_tags_wrong_order():
 @pytest.mark.asyncio
 async def test_llm_call_exception_enrichment_with_model_and_endpoint():
     """Test that LLM invocation errors include model and endpoint context."""
-    mock_llm = AsyncMock()
-    mock_llm.__module__ = "langchain_openai.chat_models"
+    mock_llm = MockOpenAILLM()
     mock_llm.model_name = "gpt-4"
     mock_llm.base_url = "https://api.openai.com/v1"
     mock_llm.ainvoke = AsyncMock(side_effect=ConnectionError("Connection refused"))
 
     with pytest.raises(LLMCallException) as exc_info:
-        await llm_call(mock_llm, "test prompt")
+        await llm_call(cast(BaseLanguageModel, mock_llm), "test prompt")
 
     exc_str = str(exc_info.value)
     assert "gpt-4" in exc_str
@@ -351,14 +370,13 @@ async def test_llm_call_exception_without_endpoint():
 @pytest.mark.asyncio
 async def test_llm_call_exception_extracts_azure_endpoint():
     """Test that Azure-style endpoint URLs are extracted."""
-    mock_llm = AsyncMock()
-    mock_llm.__module__ = "langchain_openai.chat_models"
+    mock_llm = MockAzureLLM()
     mock_llm.model_name = "gpt-4"
     mock_llm.azure_endpoint = "https://example.openai.azure.com"
     mock_llm.ainvoke = AsyncMock(side_effect=Exception("Azure error"))
 
     with pytest.raises(LLMCallException) as exc_info:
-        await llm_call(mock_llm, "test prompt")
+        await llm_call(cast(BaseLanguageModel, mock_llm), "test prompt")
 
     exc_str = str(exc_info.value)
     assert "https://example.openai.azure.com" in exc_str
@@ -369,14 +387,13 @@ async def test_llm_call_exception_extracts_azure_endpoint():
 @pytest.mark.asyncio
 async def test_llm_call_exception_extracts_server_url():
     """Test that TRT-style server_url is extracted."""
-    mock_llm = AsyncMock()
-    mock_llm.__module__ = "nemoguardrails.llm.providers.trtllm.llm"
+    mock_llm = MockTRTLLM()
     mock_llm.model_name = "llama-2-70b"
     mock_llm.server_url = "https://triton.example.com:8000"
     mock_llm.ainvoke = AsyncMock(side_effect=Exception("Triton server error"))
 
     with pytest.raises(LLMCallException) as exc_info:
-        await llm_call(mock_llm, "test prompt")
+        await llm_call(cast(BaseLanguageModel, mock_llm), "test prompt")
 
     exc_str = str(exc_info.value)
     assert "https://triton.example.com:8000" in exc_str
@@ -387,18 +404,12 @@ async def test_llm_call_exception_extracts_server_url():
 @pytest.mark.asyncio
 async def test_llm_call_exception_extracts_nested_client_base_url():
     """Test that nested client.base_url is extracted."""
-
-    class MockClient:
-        base_url = "https://custom.endpoint.com/v1"
-
-    mock_llm = AsyncMock()
-    mock_llm.__module__ = "langchain_openai.chat_models"
+    mock_llm = MockLLMWithClient()
     mock_llm.model_name = "gpt-4-turbo"
-    mock_llm.client = MockClient()
     mock_llm.ainvoke = AsyncMock(side_effect=Exception("Client error"))
 
     with pytest.raises(LLMCallException) as exc_info:
-        await llm_call(mock_llm, "test prompt")
+        await llm_call(cast(BaseLanguageModel, mock_llm), "test prompt")
 
     exc_str = str(exc_info.value)
     assert "https://custom.endpoint.com/v1" in exc_str
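A note on the cast(BaseLanguageModel, mock_llm) calls added above: typing.cast is purely a static-typing hint, so the plain mock classes still work at runtime even though they no longer derive from AsyncMock and are not real BaseLanguageModel subclasses. A small self-contained illustration, with a made-up FakeLLM class:

from typing import cast

from langchain_core.language_models import BaseLanguageModel


class FakeLLM:
    """Made-up stand-in, duck-typed like the mocks above; not a real subclass."""

    async def ainvoke(self, prompt: str) -> str:
        return "ok"


fake = FakeLLM()
llm = cast(BaseLanguageModel, fake)  # no runtime check or conversion happens
assert llm is fake  # cast() just returns its argument unchanged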

tests/test_config_validation.py

Lines changed: 2 additions & 2 deletions

@@ -75,7 +75,7 @@ def test_self_check_input_prompt_exception():
         )
         LLMRails(config=config)
 
-    assert "You must provide a `self_check_input` prompt" in str(exc_info.value)
+    assert "Missing a `self_check_input` prompt template" in str(exc_info.value)
 
 
 def test_self_check_output_prompt_exception():
@@ -90,7 +90,7 @@ def test_self_check_output_prompt_exception():
         )
         LLMRails(config=config)
 
-    assert "You must provide a `self_check_output` prompt" in str(exc_info.value)
+    assert "Missing a `self_check_output` prompt template" in str(exc_info.value)
 
 
 def test_passthrough_and_single_call_incompatibility():
