Skip to content

Commit 815fc34

Browse files
committed
Fix more tests
1 parent df829e4 commit 815fc34

File tree

2 files changed

+31
-27
lines changed

2 files changed

+31
-27
lines changed

tests/llm_providers/test_langchain_initializer.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,25 +70,27 @@ def test_special_case_called_first(mock_initializers):
7070

7171
def test_chat_completion_called(mock_initializers):
7272
mock_initializers["special"].return_value = None
73+
mock_initializers["text"].return_value = None
7374
mock_initializers["chat"].return_value = "chat_model"
7475
result = init_langchain_model("chat-model", "provider", "chat", {})
7576
assert result == "chat_model"
7677
mock_initializers["special"].assert_called_once()
78+
mock_initializers["text"].assert_called_once()
7779
mock_initializers["chat"].assert_called_once()
7880
mock_initializers["community"].assert_not_called()
79-
mock_initializers["text"].assert_not_called()
8081

8182

8283
def test_community_chat_called(mock_initializers):
8384
mock_initializers["special"].return_value = None
85+
mock_initializers["text"].return_value = None
8486
mock_initializers["chat"].return_value = None
8587
mock_initializers["community"].return_value = "community_model"
8688
result = init_langchain_model("community-chat", "provider", "chat", {})
8789
assert result == "community_model"
8890
mock_initializers["special"].assert_called_once()
91+
mock_initializers["text"].assert_called_once()
8992
mock_initializers["chat"].assert_called_once()
9093
mock_initializers["community"].assert_called_once()
91-
mock_initializers["text"].assert_not_called()
9294

9395

9496
def test_text_completion_called(mock_initializers):
@@ -154,36 +156,39 @@ def test_all_initializers_raise_exceptions(mock_initializers):
154156

155157
def test_duplicate_modes_in_initializer(mock_initializers):
156158
mock_initializers["special"].return_value = None
159+
mock_initializers["text"].return_value = None
157160
mock_initializers["chat"].return_value = "chat_model"
158161
result = init_langchain_model("chat-model", "provider", "chat", {})
159162
assert result == "chat_model"
160163
mock_initializers["special"].assert_called_once()
164+
mock_initializers["text"].assert_called_once()
161165
mock_initializers["chat"].assert_called_once()
162166
mock_initializers["community"].assert_not_called()
163-
mock_initializers["text"].assert_not_called()
164167

165168

166169
def test_chat_completion_called_when_special_returns_none(mock_initializers):
167170
mock_initializers["special"].return_value = None
171+
mock_initializers["text"].return_value = None
168172
mock_initializers["chat"].return_value = "chat_model"
169173
result = init_langchain_model("chat-model", "provider", "chat", {})
170174
assert result == "chat_model"
171175
mock_initializers["special"].assert_called_once()
176+
mock_initializers["text"].assert_called_once()
172177
mock_initializers["chat"].assert_called_once()
173178
mock_initializers["community"].assert_not_called()
174-
mock_initializers["text"].assert_not_called()
175179

176180

177181
def test_community_chat_called_when_previous_fail(mock_initializers):
178182
mock_initializers["special"].return_value = None
183+
mock_initializers["text"].return_value = None
179184
mock_initializers["chat"].return_value = None
180185
mock_initializers["community"].return_value = "community_model"
181186
result = init_langchain_model("community-chat", "provider", "chat", {})
182187
assert result == "community_model"
183188
mock_initializers["special"].assert_called_once()
189+
mock_initializers["text"].assert_called_once()
184190
mock_initializers["chat"].assert_called_once()
185191
mock_initializers["community"].assert_called_once()
186-
mock_initializers["text"].assert_not_called()
187192

188193

189194
def test_text_completion_called_when_previous_fail(mock_initializers):
@@ -201,12 +206,11 @@ def test_text_completion_called_when_previous_fail(mock_initializers):
201206

202207
def test_text_completion_supports_chat_mode(mock_initializers):
203208
mock_initializers["special"].return_value = None
204-
mock_initializers["chat"].return_value = None
205-
mock_initializers["community"].return_value = None
206209
mock_initializers["text"].return_value = "text_model"
207210
result = init_langchain_model("text-model", "provider", "chat", {})
208211
assert result == "text_model"
209212
mock_initializers["special"].assert_called_once()
210-
mock_initializers["chat"].assert_called_once()
211-
mock_initializers["community"].assert_called_once()
212213
mock_initializers["text"].assert_called_once()
214+
# Since text returns a value, chat and community are not called
215+
mock_initializers["chat"].assert_not_called()
216+
mock_initializers["community"].assert_not_called()

tests/test_rails_config.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_check_prompt_exist_for_self_check_rails():
9898
],
9999
}
100100
with pytest.raises(
101-
ValueError, match="You must provide a `self_check_output` prompt template"
101+
ValueError, match="Missing a `self_check_output` prompt template"
102102
):
103103
RailsConfig.check_prompt_exist_for_self_check_rails(values)
104104

@@ -353,7 +353,7 @@ def test_validate_rail_prompts_wrong_flow_id_raises(self):
353353

354354
with pytest.raises(
355355
ValueError,
356-
match="You must provide a `content_safety_check_input \$model=content_safety` prompt template.",
356+
match="Missing a `content_safety_check_input \$model=content_safety` prompt template",
357357
):
358358
_validate_rail_prompts(
359359
["content safety check input $model=content_safety"],
@@ -366,7 +366,7 @@ def test_validate_rail_prompts_wrong_model_raises(self):
366366

367367
with pytest.raises(
368368
ValueError,
369-
match="You must provide a `content_safety_check_input \$model=content_safety` prompt template.",
369+
match="Missing a `content_safety_check_input \$model=content_safety` prompt template",
370370
):
371371
_validate_rail_prompts(
372372
["content safety check input $model=content_safety"],
@@ -379,7 +379,7 @@ def test_validate_rail_prompts_no_prompt_raises(self):
379379

380380
with pytest.raises(
381381
ValueError,
382-
match="You must provide a `content_safety_check_input \$model=content_safety` prompt template.",
382+
match="Missing a `content_safety_check_input \$model=content_safety` prompt template",
383383
):
384384
_validate_rail_prompts(
385385
["content safety check input $model=content_safety"],
@@ -395,7 +395,7 @@ def test_content_safety_input_missing_prompt_raises(self):
395395
"""Check Content Safety output rail raises ValueError if we don't have a prompt"""
396396
with pytest.raises(
397397
ValueError,
398-
match="You must provide a `content_safety_check_input \$model=content_safety` prompt template.",
398+
match="Missing a `content_safety_check_input \$model=content_safety` prompt template",
399399
):
400400
_ = RailsConfig.from_content(
401401
yaml_content="""
@@ -415,7 +415,7 @@ def test_content_safety_output_missing_prompt_raises(self):
415415
"""Check Content Safety output rail raises ValueError if we don't have a prompt"""
416416
with pytest.raises(
417417
ValueError,
418-
match="You must provide a `content_safety_check_output \$model=content_safety` prompt template.",
418+
match="Missing a `content_safety_check_output \$model=content_safety` prompt template",
419419
):
420420
_ = RailsConfig.from_content(
421421
yaml_content="""
@@ -531,7 +531,7 @@ def test_input_content_safety_no_model_raises(self):
531531

532532
with pytest.raises(
533533
ValueError,
534-
match="No `content_safety` model provided for input flow `content safety check input`",
534+
match="Input flow 'content safety check input' references model type 'content_safety' that is not defined",
535535
):
536536
_ = RailsConfig.from_content(
537537
yaml_content="""
@@ -556,7 +556,7 @@ def test_input_content_safety_wrong_model_raises(self):
556556

557557
with pytest.raises(
558558
ValueError,
559-
match="No `content_safety` model provided for input flow `content safety check input",
559+
match="Input flow 'content safety check input' references model type 'content_safety' that is not defined",
560560
):
561561
_ = RailsConfig.from_content(
562562
yaml_content="""
@@ -581,7 +581,7 @@ def test_output_content_safety_no_model_raises(self):
581581

582582
with pytest.raises(
583583
ValueError,
584-
match="No `content_safety` model provided for output flow `content safety check output`",
584+
match="Output flow 'content safety check output' references model type 'content_safety' that is not defined",
585585
):
586586
_ = RailsConfig.from_content(
587587
yaml_content="""
@@ -606,7 +606,7 @@ def test_output_content_safety_wrong_model_raises(self):
606606

607607
with pytest.raises(
608608
ValueError,
609-
match="You must provide a `content_safety_check_output \$model=content_safety` prompt template",
609+
match="Missing a `content_safety_check_output \$model=content_safety` prompt template",
610610
):
611611
_ = RailsConfig.from_content(
612612
yaml_content="""
@@ -664,7 +664,7 @@ def test_topic_safety_no_prompt_raises(self):
664664

665665
with pytest.raises(
666666
ValueError,
667-
match="You must provide a `topic_safety_check_input \$model=topic_control` prompt template",
667+
match="Missing a `topic_safety_check_input \$model=topic_control` prompt template",
668668
):
669669
_ = RailsConfig.from_content(
670670
yaml_content="""
@@ -688,7 +688,7 @@ def test_topic_safety_no_model_raises(self):
688688
"""Check if we don't provide a topic-safety model we raise a ValueError"""
689689
with pytest.raises(
690690
ValueError,
691-
match="No `topic_control` model provided for input flow `topic safety check input`",
691+
match="Input flow 'topic safety check input' references model type 'topic_control' that is not defined",
692692
):
693693
_ = RailsConfig.from_content(
694694
yaml_content="""
@@ -712,7 +712,7 @@ def test_topic_safety_no_model_no_prompt_raises(self):
712712
"""Check a missing model and prompt raises ValueError"""
713713
with pytest.raises(
714714
ValueError,
715-
match="You must provide a `topic_safety_check_input \$model=topic_control` prompt template",
715+
match="Missing a `topic_safety_check_input \$model=topic_control` prompt template",
716716
):
717717
_ = RailsConfig.from_content(
718718
yaml_content="""
@@ -741,7 +741,7 @@ def test_hero_separate_models_no_prompts_raises(self):
741741

742742
with pytest.raises(
743743
ValueError,
744-
match="You must provide a `content_safety_check_input \$model=my_content_safety` prompt template",
744+
match="Missing a `content_safety_check_input \$model=my_content_safety` prompt template",
745745
):
746746
_ = RailsConfig.from_content(
747747
yaml_content="""
@@ -883,7 +883,7 @@ def test_hero_no_prompts_raises(self):
883883
"""Create hero workflow with no prompts. Expect Content Safety input prompt check to fail"""
884884
with pytest.raises(
885885
ValueError,
886-
match="You must provide a `content_safety_check_input \$model=content_safety` prompt template",
886+
match="Missing a `content_safety_check_input \$model=content_safety` prompt template",
887887
):
888888
_ = RailsConfig.from_content(
889889
yaml_content="""
@@ -923,7 +923,7 @@ def test_hero_no_output_content_safety_prompt_raises(self):
923923
"""Create hero workflow with no prompts. Expect Content Safety input prompt check to fail"""
924924
with pytest.raises(
925925
ValueError,
926-
match="You must provide a `topic_safety_check_input \$model=your_topic_control` prompt template",
926+
match="Missing a `topic_safety_check_input \$model=your_topic_control` prompt template",
927927
):
928928
_ = RailsConfig.from_content(
929929
yaml_content="""
@@ -967,7 +967,7 @@ def test_hero_no_topic_safety_prompt_raises(self):
967967
"""Create hero workflow with no prompts. Expect Content Safety input prompt check to fail"""
968968
with pytest.raises(
969969
ValueError,
970-
match="You must provide a `topic_safety_check_input \$model=your_topic_control` prompt template",
970+
match="Missing a `topic_safety_check_input \$model=your_topic_control` prompt template",
971971
):
972972
_ = RailsConfig.from_content(
973973
yaml_content="""
@@ -1013,7 +1013,7 @@ def test_hero_topic_safety_prompt_raises(self):
10131013
"""Create hero workflow with no prompts. Expect Content Safety input prompt check to fail"""
10141014
with pytest.raises(
10151015
ValueError,
1016-
match="You must provide a `content_safety_check_input \$model=content_safety` prompt template",
1016+
match="Missing a `content_safety_check_input \$model=content_safety` prompt template",
10171017
):
10181018
_ = RailsConfig.from_content(
10191019
yaml_content="""

0 commit comments

Comments
 (0)