Skip to content

Commit e352d54

Browse files
committed
Move test runner to decorator
1 parent b02c85c commit e352d54

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

evaluation_function/evaluation_test.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import unittest
2-
import json
32

43
from .evaluation import Params, evaluation_function
5-
from .json_tests import get_tests_from_json
4+
from .json_tests import auto_test
65

6+
@auto_test("eval_tests.json", evaluation_function)
77
class TestEvaluationFunction(unittest.TestCase):
88
"""
99
TestCase Class used to test the algorithm.
@@ -106,9 +106,3 @@ def test_brackets(self):
106106

107107
self.assertEqual(result.get("is_correct"), True)
108108
self.assertFalse(result.get("feedback"))
109-
110-
def test_auto(self):
111-
tests = get_tests_from_json("eval_tests.json")
112-
for test in tests:
113-
results = test.evaluate()
114-
self.assertTrue(*test.compare(results.to_dict()))

evaluation_function/json_tests.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import json
2-
from .evaluation import evaluation_function
32

43
class TestData:
54
def __init__(self, test_dict: dict):
@@ -9,8 +8,8 @@ def __init__(self, test_dict: dict):
98
self.is_correct = test_dict["is_correct"]
109
self.results = test_dict.get("results")
1110

12-
def evaluate(self) -> dict:
13-
return evaluation_function(self.response, self.answer, self.params)
11+
def evaluate(self, func) -> dict:
12+
return func(self.response, self.answer, self.params)
1413

1514
def compare(self, eval_result: dict) -> tuple[bool, str]:
1615
eval_correct = eval_result["is_correct"]
@@ -48,3 +47,15 @@ def get_tests_from_json(filename: str) -> list[TestData]:
4847
out.append(TestData(test))
4948

5049
return out
50+
51+
def auto_test(path, func):
52+
def _auto_test(orig_class):
53+
def test_auto(self):
54+
tests = get_tests_from_json(path)
55+
for test in tests:
56+
results = test.evaluate(func)
57+
self.assertTrue(*test.compare(results.to_dict()))
58+
59+
orig_class.test_auto = test_auto # Add the test_auto function to the class
60+
return orig_class
61+
return _auto_test

0 commit comments

Comments
 (0)