1- import json
21import yaml
32from typing import Union
3+ from dataclasses import dataclass
44
55class TestFile :
6- """An abstraction over a test file, which may be in one of several different formats .
7- Currently, JSON and YAML are supported.
6+ """An abstraction over a test file.
7+ Currently, only YAML files are supported.
88 """
99
10- def __init__ (self , path : str ) -> None :
10+ def __init__ (self , file_content : str , file_name : str ) -> None :
1111 self .groups = []
1212
13- # Attempt to open the given file. Exit with an error if this
14- # is not possible.
15- file_content = ""
16- try :
17- with open (path , "r" ) as test_file :
18- file_content = test_file .read ()
19- except IOError as e :
20- raise Exception (f'Failed to open test file: "{ e } "' )
21-
2213 # Get the file extension to determine which format should be used.
23- extension = path .split ("." )[- 1 ]
24- if extension == "json" :
25- try :
26- questions = json .loads (file_content )
27-
28- for question in questions :
29- out = []
30- title = question ["title" ]
31- for part in question ["parts" ]:
32- for response_area in part ["responseAreas" ]:
33- params = response_area ["params" ]
34- answer = response_area ["answer" ]
35- for test in response_area ["tests" ]:
36- test .update ({"answer" : answer })
37- test .update ({"params" : params })
38- out .append (SingleTest (test ))
39- self .groups .append ({"title" : title , "tests" : out })
40-
41- except KeyError as e :
42- raise Exception (f'The key "{ e .args [0 ]} " doesn\' t exist, or is in the wrong place.' )
43- except json .JSONDecodeError as e :
44- raise Exception (f'Error parsing JSON: "{ e } "' )
45- elif extension == "yaml" :
14+ extension = file_name .split ("." )[- 1 ]
15+ if extension == "yaml" :
4616 try :
4717 # Tests are organised in groups of separate YAML documents (separated by "---")
4818 docs = yaml .safe_load_all (file_content )
@@ -53,66 +23,85 @@ def __init__(self, path: str) -> None:
5323 # Add an empty params field if none was provided.
5424 if test .get ("params" ) == None :
5525 test ["params" ] = {}
56-
57- # Does this test have sub-tests?
58- sub_tests = test .get ("sub_tests" )
59- if sub_tests != None :
60- params = test ["params" ]
61- answer = test ["answer" ]
62-
63- for sub_test in sub_tests :
64- sub_test ["params" ] = params
65- sub_test ["answer" ] = answer
66- tests .append (SingleTest (sub_test ))
67- else :
68- tests .append (SingleTest (test ))
26+
27+ tests .append (SingleTest (test ))
6928
7029 self .groups .append ({"title" : title , "tests" : tests })
7130 except yaml .YAMLError as e :
7231 raise Exception (f'Error parsing YAML: { e } ' )
7332 else :
7433 raise Exception (f'"{ extension } " files are not supported as a test format.' )
7534
35+
7636class SingleTest :
7737 def __init__ (self , test_dict : dict ):
78- self .response = test_dict .get ("response" , "" )
7938 self .answer = test_dict .get ("answer" , "" )
8039 self .params = test_dict .get ("params" , {})
81- expected_result = test_dict .get ("expected_result" )
82- if not expected_result :
83- raise Exception ("No expected result given for test" )
84- self .is_correct = expected_result .get ("is_correct" )
85- self .results = expected_result
8640 self .desc = test_dict .get ("description" , "" )
8741
88- def evaluate (self , func ) -> dict :
89- return func (self .response , self .answer , self .params )
42+ self .sub_tests = []
43+ if "sub_tests" in test_dict :
44+ for sub_test in test_dict ["sub_tests" ]:
45+ expected_result = sub_test .get ("expected_result" )
46+ if not expected_result :
47+ raise Exception ("No expected result given for test" )
48+
49+ self .sub_tests .append (SubTest (
50+ sub_test .get ("description" , "" ),
51+ sub_test .get ("response" , "" ),
52+ expected_result .get ("is_correct" ),
53+ expected_result ,
54+ ))
55+ else :
56+ expected_result = test_dict .get ("expected_result" )
57+ if not expected_result :
58+ raise Exception ("No expected result given for test" )
59+
60+ self .sub_tests .append (SubTest (
61+ "" ,
62+ test_dict .get ("response" , "" ),
63+ expected_result .get ("is_correct" ),
64+ expected_result ,
65+ ))
66+
67+ def evaluate_all (self , func ) -> list [dict ]:
68+ return [func (test .response , self .answer , self .params ) for test in self .sub_tests ]
9069
91- def compare (self , eval_result : dict ) -> tuple [bool , str ]:
92- eval_correct = eval_result ["is_correct" ]
93-
94- if eval_correct != self .is_correct :
95- return (
96- False ,
97- f"response \" { self .response } \" with answer \" { self .answer } \" was { '' if eval_correct else 'in' } correct: { eval_result ['feedback' ]} \n Test description: { self .desc } "
98- )
99-
100- # Are there any other fields in the eval function result that need to be checked?
101- if self .results != None :
102- # Check each one in turn
103- for key , value in self .results .items ():
104- actual_result_val = eval_result .get (key )
105- if actual_result_val == None :
106- return (False , f"No value returned for \" { key } \" " )
70+ def compare_all (self , eval_results : list [dict ]) -> tuple [bool , str ]:
71+ for i , eval_result in enumerate (eval_results ):
72+ eval_correct = eval_result ["is_correct" ]
10773
108- if actual_result_val != value :
109- return (
110- False ,
111- f"expected { key } = \" { value } \" , got { key } = \" { actual_result_val } \" \n Test description: { self .desc } "
112- )
74+ if eval_correct != self .sub_tests [i ].is_correct :
75+ return (
76+ False ,
77+ (f"response \" { self .sub_tests [i ].response } \" with answer "
78+ f"\" { self .answer } \" was { '' if eval_correct else 'in' } correct: "
79+ f"{ eval_result ['feedback' ]} \n Test description: { self .sub_tests [i ].desc } " )
80+ )
81+
82+ # Are there any other fields in the eval function result that need to be checked?
83+ if self .sub_tests [i ].expected_result != None :
84+ # Check each one in turn
85+ for key , value in self .sub_tests [i ].expected_result .items ():
86+ actual_result_val = eval_result .get (key )
87+ if actual_result_val == None :
88+ return (False , f"No value returned for \" { key } \" " )
89+
90+ if actual_result_val != value :
91+ return (
92+ False ,
93+ f"expected { key } = \" { value } \" , got { key } = \" { actual_result_val } \" \n Test description: { self .desc } "
94+ )
11395
11496 return (True , "" )
11597
98+ @dataclass
99+ class SubTest :
100+ desc : str
101+ response : str
102+ is_correct : bool
103+ expected_result : dict
104+
116105
117106def auto_test (path , func ):
118107 """A decorator that adds the necessary infrastructure to run tests defined
@@ -124,16 +113,18 @@ def _auto_test(orig_class):
124113 def test_auto (self ):
125114 # Creating a TestFile can fail for several reasons.
126115 # If so, an exception is raised with a suitable error message
116+ tests = {}
127117 try :
128- tests = TestFile (path )
118+ with open (path , "r" ) as f :
119+ tests = TestFile (f .read (), path )
129120 except Exception as e :
130121 self .fail (e )
131122
132123 # Successfully loaded
133124 for group in tests .groups :
134125 for test in group ["tests" ]:
135- results = test .evaluate (func )
136- self .assertTrue (* test .compare ( results .to_dict ()))
126+ results = test .evaluate_all (func )
127+ self .assertTrue (* test .compare_all ( map ( lambda r : r .to_dict (), results )))
137128
138129 orig_class .test_auto = test_auto # Add the test_auto function to the class
139130 return orig_class
0 commit comments