@@ -13,9 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import cast
 from unittest.mock import AsyncMock
 
 import pytest
+from langchain_core.language_models import BaseLanguageModel
 
 from nemoguardrails.actions.llm.utils import (
     _extract_and_remove_think_tags,
@@ -55,6 +57,24 @@ class MockPatchedNVIDIA(MockNVIDIAOriginal):
     __module__ = "nemoguardrails.llm.providers._langchain_nvidia_ai_endpoints_patch"
 
 
+class MockTRTLLM:
+    __module__ = "nemoguardrails.llm.providers.trtllm.llm"
+
+
+class MockAzureLLM:
+    __module__ = "langchain_openai.chat_models"
+
+
+class MockLLMWithClient:
+    __module__ = "langchain_openai.chat_models"
+
+    class _MockClient:
+        base_url = "https://custom.endpoint.com/v1"
+
+    def __init__(self):
+        self.client = self._MockClient()
+
+
 def test_infer_provider_openai():
     llm = MockOpenAILLM()
     provider = _infer_provider_from_module(llm)
@@ -315,14 +335,13 @@ def test_extract_and_remove_think_tags_wrong_order():
 @pytest.mark.asyncio
 async def test_llm_call_exception_enrichment_with_model_and_endpoint():
     """Test that LLM invocation errors include model and endpoint context."""
-    mock_llm = AsyncMock()
-    mock_llm.__module__ = "langchain_openai.chat_models"
+    mock_llm = MockOpenAILLM()
     mock_llm.model_name = "gpt-4"
     mock_llm.base_url = "https://api.openai.com/v1"
     mock_llm.ainvoke = AsyncMock(side_effect=ConnectionError("Connection refused"))
 
     with pytest.raises(LLMCallException) as exc_info:
-        await llm_call(mock_llm, "test prompt")
+        await llm_call(cast(BaseLanguageModel, mock_llm), "test prompt")
 
     exc_str = str(exc_info.value)
     assert "gpt-4" in exc_str
@@ -351,14 +370,13 @@ async def test_llm_call_exception_without_endpoint():
 @pytest.mark.asyncio
 async def test_llm_call_exception_extracts_azure_endpoint():
     """Test that Azure-style endpoint URLs are extracted."""
-    mock_llm = AsyncMock()
-    mock_llm.__module__ = "langchain_openai.chat_models"
+    mock_llm = MockAzureLLM()
     mock_llm.model_name = "gpt-4"
     mock_llm.azure_endpoint = "https://example.openai.azure.com"
     mock_llm.ainvoke = AsyncMock(side_effect=Exception("Azure error"))
 
     with pytest.raises(LLMCallException) as exc_info:
-        await llm_call(mock_llm, "test prompt")
+        await llm_call(cast(BaseLanguageModel, mock_llm), "test prompt")
 
     exc_str = str(exc_info.value)
     assert "https://example.openai.azure.com" in exc_str
@@ -369,14 +387,13 @@ async def test_llm_call_exception_extracts_azure_endpoint():
 @pytest.mark.asyncio
 async def test_llm_call_exception_extracts_server_url():
     """Test that TRT-style server_url is extracted."""
-    mock_llm = AsyncMock()
-    mock_llm.__module__ = "nemoguardrails.llm.providers.trtllm.llm"
+    mock_llm = MockTRTLLM()
     mock_llm.model_name = "llama-2-70b"
     mock_llm.server_url = "https://triton.example.com:8000"
     mock_llm.ainvoke = AsyncMock(side_effect=Exception("Triton server error"))
 
     with pytest.raises(LLMCallException) as exc_info:
-        await llm_call(mock_llm, "test prompt")
+        await llm_call(cast(BaseLanguageModel, mock_llm), "test prompt")
 
     exc_str = str(exc_info.value)
     assert "https://triton.example.com:8000" in exc_str
@@ -387,18 +404,12 @@ async def test_llm_call_exception_extracts_server_url():
 @pytest.mark.asyncio
 async def test_llm_call_exception_extracts_nested_client_base_url():
     """Test that nested client.base_url is extracted."""
-
-    class MockClient:
-        base_url = "https://custom.endpoint.com/v1"
-
-    mock_llm = AsyncMock()
-    mock_llm.__module__ = "langchain_openai.chat_models"
+    mock_llm = MockLLMWithClient()
     mock_llm.model_name = "gpt-4-turbo"
-    mock_llm.client = MockClient()
     mock_llm.ainvoke = AsyncMock(side_effect=Exception("Client error"))
 
     with pytest.raises(LLMCallException) as exc_info:
-        await llm_call(mock_llm, "test prompt")
+        await llm_call(cast(BaseLanguageModel, mock_llm), "test prompt")
 
     exc_str = str(exc_info.value)
     assert "https://custom.endpoint.com/v1" in exc_str
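
Taken together, the diff swaps untyped AsyncMock objects for small classes whose __module__ mimics a real provider module, then wraps them in typing.cast so llm_call still type-checks against BaseLanguageModel. A minimal sketch of why this pattern works, assuming module-based provider inference; the _infer_provider helper below is a hypothetical stand-in for _infer_provider_from_module, whose actual implementation is not shown in this diff:

from typing import cast

from langchain_core.language_models import BaseLanguageModel


def _infer_provider(llm: object) -> str:
    # Hypothetical stand-in: infer the provider from the class's __module__,
    # which is exactly what the mock classes above override.
    module = type(llm).__module__
    if module.startswith("langchain_openai"):
        return "openai"
    if "trtllm" in module:
        return "trt_llm"
    return "unknown"


class MockAzureLLM:
    # A class-level __module__ override is all module-based inference sees.
    __module__ = "langchain_openai.chat_models"


mock_llm = MockAzureLLM()
assert _infer_provider(mock_llm) == "openai"

# cast() is a runtime no-op; it only tells the static type checker to treat
# the mock as a BaseLanguageModel where the annotated signature requires one.
typed_llm = cast(BaseLanguageModel, mock_llm)
assert typed_llm is mock_llm

Because cast neither wraps nor validates anything, the mocks keep their AsyncMock ainvoke attributes at runtime while satisfying the annotation.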