Skip to content

Commit 8b18bb5

Browse files
committed
add tests
fix
1 parent 3b5d09e commit 8b18bb5

File tree

2 files changed

+235
-2
lines changed

2 files changed

+235
-2
lines changed

tests/test_content_safety_actions.py

Lines changed: 160 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,31 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from unittest.mock import MagicMock
16+
from unittest.mock import MagicMock, patch
1717

18-
# conftest.py
1918
import pytest
2019

2120
from nemoguardrails.library.content_safety.actions import (
21+
DEFAULT_REFUSAL_MESSAGES,
22+
SUPPORTED_LANGUAGES,
23+
_detect_language,
24+
_get_refusal_message,
2225
content_safety_check_input,
2326
content_safety_check_output,
2427
content_safety_check_output_mapping,
28+
detect_language,
2529
)
2630
from tests.utils import FakeLLM
2731

32+
try:
33+
import fast_langdetect # noqa
34+
35+
HAS_FAST_LANGDETECT = True
36+
except ImportError:
37+
HAS_FAST_LANGDETECT = False
38+
39+
requires_fast_langdetect = pytest.mark.skipif(not HAS_FAST_LANGDETECT, reason="fast-langdetect not installed")
40+
2841

2942
@pytest.fixture
3043
def fake_llm():
@@ -150,3 +163,148 @@ def test_content_safety_check_output_mapping_default():
150163
"""Test content_safety_check_output_mapping defaults to allowed=False when key is missing."""
151164
result = {"policy_violations": []}
152165
assert content_safety_check_output_mapping(result) is False
166+
167+
168+
@requires_fast_langdetect
169+
class TestDetectLanguage:
170+
@pytest.mark.parametrize(
171+
"text,expected_lang",
172+
[
173+
("Hello, how are you today?", "en"),
174+
("Hola, ¿cómo estás hoy?", "es"),
175+
("你好,你今天好吗?", "zh"),
176+
("Guten Tag, wie geht es Ihnen?", "de"),
177+
("Bonjour, comment allez-vous?", "fr"),
178+
("こんにちは、お元気ですか?", "ja"),
179+
],
180+
ids=["english", "spanish", "chinese", "german", "french", "japanese"],
181+
)
182+
def test_detect_language(self, text, expected_lang):
183+
assert _detect_language(text) == expected_lang
184+
185+
def test_detect_language_empty_string(self):
186+
result = _detect_language("")
187+
assert result is None or result == "en"
188+
189+
def test_detect_language_import_error(self):
190+
with patch.dict("sys.modules", {"fast_langdetect": None}):
191+
import nemoguardrails.library.content_safety.actions as actions_module
192+
193+
_original_detect_language = actions_module._detect_language
194+
195+
def patched_detect_language(text):
196+
try:
197+
raise ImportError("No module named 'fast_langdetect'")
198+
except ImportError:
199+
return None
200+
201+
with patch.object(actions_module, "_detect_language", patched_detect_language):
202+
result = actions_module._detect_language("Hello")
203+
assert result is None
204+
205+
def test_detect_language_exception(self):
206+
with patch("fast_langdetect.detect", side_effect=Exception("Detection failed")):
207+
result = _detect_language("Hello")
208+
assert result is None
209+
210+
211+
class TestGetRefusalMessage:
212+
@pytest.mark.parametrize("lang", list(SUPPORTED_LANGUAGES))
213+
def test_default_messages(self, lang):
214+
result = _get_refusal_message(lang, None)
215+
assert result == DEFAULT_REFUSAL_MESSAGES[lang]
216+
217+
def test_custom_message_used_when_available(self):
218+
custom = {"en": "Custom refusal", "es": "Rechazo personalizado"}
219+
assert _get_refusal_message("en", custom) == "Custom refusal"
220+
assert _get_refusal_message("es", custom) == "Rechazo personalizado"
221+
222+
def test_unsupported_lang_falls_back_to_english(self):
223+
assert _get_refusal_message("xyz", None) == DEFAULT_REFUSAL_MESSAGES["en"]
224+
assert _get_refusal_message("xyz", {"en": "Custom fallback"}) == "Custom fallback"
225+
226+
def test_lang_not_in_custom_uses_default(self):
227+
custom = {"en": "Custom English"}
228+
assert _get_refusal_message("es", custom) == DEFAULT_REFUSAL_MESSAGES["es"]
229+
230+
231+
@requires_fast_langdetect
232+
class TestDetectLanguageAction:
233+
@pytest.mark.asyncio
234+
@pytest.mark.parametrize(
235+
"user_message,expected_lang",
236+
[
237+
("Hello, how are you?", "en"),
238+
("Hola, ¿cómo estás?", "es"),
239+
("你好", "zh"),
240+
],
241+
ids=["english", "spanish", "chinese"],
242+
)
243+
async def test_detect_language_action(self, user_message, expected_lang):
244+
context = {"user_message": user_message}
245+
result = await detect_language(context=context, config=None)
246+
assert result["language"] == expected_lang
247+
assert result["refusal_message"] == DEFAULT_REFUSAL_MESSAGES[expected_lang]
248+
249+
@pytest.mark.asyncio
250+
@pytest.mark.parametrize(
251+
"context",
252+
[None, {"user_message": ""}],
253+
ids=["no_context", "empty_message"],
254+
)
255+
async def test_detect_language_action_defaults_to_english(self, context):
256+
result = await detect_language(context=context, config=None)
257+
assert result["language"] == "en"
258+
assert result["refusal_message"] == DEFAULT_REFUSAL_MESSAGES["en"]
259+
260+
@pytest.mark.asyncio
261+
async def test_detect_language_action_unsupported_language_falls_back_to_english(self):
262+
with patch(
263+
"nemoguardrails.library.content_safety.actions._detect_language",
264+
return_value="xyz",
265+
):
266+
context = {"user_message": "some text"}
267+
result = await detect_language(context=context, config=None)
268+
assert result["language"] == "en"
269+
assert result["refusal_message"] == DEFAULT_REFUSAL_MESSAGES["en"]
270+
271+
@pytest.mark.asyncio
272+
async def test_detect_language_action_with_config_custom_messages(self):
273+
mock_config = MagicMock()
274+
mock_config.rails.config.content_safety.multilingual.refusal_messages = {
275+
"en": "Custom: Cannot help",
276+
"es": "Personalizado: No puedo ayudar",
277+
}
278+
279+
context = {"user_message": "Hello"}
280+
result = await detect_language(context=context, config=mock_config)
281+
assert result["language"] == "en"
282+
assert result["refusal_message"] == "Custom: Cannot help"
283+
284+
@pytest.mark.asyncio
285+
async def test_detect_language_action_with_config_no_multilingual(self):
286+
mock_config = MagicMock()
287+
mock_config.rails.config.content_safety.multilingual = None
288+
289+
context = {"user_message": "Hello"}
290+
result = await detect_language(context=context, config=mock_config)
291+
assert result["language"] == "en"
292+
assert result["refusal_message"] == DEFAULT_REFUSAL_MESSAGES["en"]
293+
294+
295+
class TestSupportedLanguagesAndDefaults:
296+
def test_supported_languages_count(self):
297+
assert len(SUPPORTED_LANGUAGES) == 9
298+
299+
def test_supported_languages_contents(self):
300+
expected = {"en", "es", "zh", "de", "fr", "hi", "ja", "ar", "th"}
301+
assert SUPPORTED_LANGUAGES == expected
302+
303+
def test_default_refusal_messages_has_all_supported_languages(self):
304+
for lang in SUPPORTED_LANGUAGES:
305+
assert lang in DEFAULT_REFUSAL_MESSAGES
306+
307+
def test_default_refusal_messages_are_non_empty(self):
308+
for _lang, message in DEFAULT_REFUSAL_MESSAGES.items():
309+
assert message
310+
assert len(message) > 0

tests/test_rails_config.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323

2424
from nemoguardrails.llm.prompts import TaskPrompt
2525
from nemoguardrails.rails.llm.config import (
26+
ContentSafetyConfig,
2627
Model,
28+
MultilingualConfig,
2729
RailsConfig,
2830
_get_flow_model,
2931
_validate_rail_prompts,
@@ -1015,3 +1017,76 @@ def test_hero_topic_safety_prompt_raises(self):
10151017
content: Verify the user input is on-topic
10161018
"""
10171019
)
1020+
1021+
1022+
class TestMultilingualConfig:
1023+
def test_defaults(self):
1024+
config = MultilingualConfig()
1025+
assert config.enabled is False
1026+
assert config.refusal_messages is None
1027+
1028+
def test_with_custom_messages(self):
1029+
custom = {"en": "Custom", "es": "Personalizado"}
1030+
config = MultilingualConfig(enabled=True, refusal_messages=custom)
1031+
assert config.enabled is True
1032+
assert config.refusal_messages == custom
1033+
1034+
1035+
class TestContentSafetyConfigModel:
1036+
def test_defaults(self):
1037+
config = ContentSafetyConfig()
1038+
assert config.multilingual.enabled is False
1039+
assert config.multilingual.refusal_messages is None
1040+
1041+
def test_with_multilingual(self):
1042+
custom = {"en": "Custom"}
1043+
config = ContentSafetyConfig(multilingual=MultilingualConfig(enabled=True, refusal_messages=custom))
1044+
assert config.multilingual.enabled is True
1045+
assert config.multilingual.refusal_messages == custom
1046+
1047+
1048+
class TestMultilingualConfigInRailsConfig:
1049+
BASE_YAML = """
1050+
models:
1051+
- type: content_safety
1052+
engine: nim
1053+
model: nvidia/llama-3.1-nemoguard-8b-content-safety
1054+
rails:
1055+
{rails_config}
1056+
input:
1057+
flows:
1058+
- content safety check input $model=content_safety
1059+
prompts:
1060+
- task: content_safety_check_input $model=content_safety
1061+
content: Check content safety
1062+
"""
1063+
1064+
def test_multilingual_disabled_by_default(self):
1065+
config = RailsConfig.from_content(yaml_content=self.BASE_YAML.format(rails_config=""))
1066+
assert config.rails.config.content_safety.multilingual.enabled is False
1067+
1068+
def test_multilingual_enabled_with_custom_messages(self):
1069+
rails_config = """
1070+
config:
1071+
content_safety:
1072+
multilingual:
1073+
enabled: true
1074+
refusal_messages:
1075+
en: "Custom English"
1076+
es: "Personalizado"
1077+
"""
1078+
config = RailsConfig.from_content(yaml_content=self.BASE_YAML.format(rails_config=rails_config))
1079+
assert config.rails.config.content_safety.multilingual.enabled is True
1080+
assert config.rails.config.content_safety.multilingual.refusal_messages["en"] == "Custom English"
1081+
assert config.rails.config.content_safety.multilingual.refusal_messages["es"] == "Personalizado"
1082+
1083+
def test_multilingual_enabled_no_custom_messages(self):
1084+
rails_config = """
1085+
config:
1086+
content_safety:
1087+
multilingual:
1088+
enabled: true
1089+
"""
1090+
config = RailsConfig.from_content(yaml_content=self.BASE_YAML.format(rails_config=rails_config))
1091+
assert config.rails.config.content_safety.multilingual.enabled is True
1092+
assert config.rails.config.content_safety.multilingual.refusal_messages is None

0 commit comments

Comments
 (0)