Skip to content

Commit e3486da

Browse files
first change
1 parent e927b68 commit e3486da

File tree

5 files changed

+191
-41
lines changed

5 files changed

+191
-41
lines changed

app/eval_single_testing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from evaluation import contains_special_math
2+
print(contains_special_math("dy/dx"))

app/evaluation.py

Lines changed: 133 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1+
import os
2+
from dotenv import load_dotenv
3+
from langchain_openai import ChatOpenAI
4+
15
from typing import Any, TypedDict
6+
from sympy import solve, Eq, simplify
7+
from sympy.parsing.sympy_parser import parse_expr, standard_transformations, implicit_multiplication_application
8+
import re
29

310

411
class Params(TypedDict):
@@ -7,30 +14,134 @@ class Params(TypedDict):
714

815
class Result(TypedDict):
916
is_correct: bool
17+
sympy_result: bool | None
18+
llm_result: bool
19+
mismatch_info: str
20+
21+
22+
transformations = standard_transformations + (implicit_multiplication_application,)
1023

1124

12-
def evaluation_function(response: Any, answer: Any, params: Params) -> Result:
25+
def contains_special_math(expr: str) -> bool:
1326
"""
14-
Function used to evaluate a student response.
15-
---
16-
The handler function passes three arguments to evaluation_function():
17-
18-
- `response` which are the answers provided by the student.
19-
- `answer` which are the correct answers to compare against.
20-
- `params` which are any extra parameters that may be useful,
21-
e.g., error tolerances.
22-
23-
The output of this function is what is returned as the API response
24-
and therefore must be JSON-encodable. It must also conform to the
25-
response schema.
26-
27-
Any standard python library may be used, as well as any package
28-
available on pip (provided it is added to requirements.txt).
29-
30-
The way you wish to structure you code (all in this function, or
31-
split into many) is entirely up to you. All that matters are the
32-
return types and that evaluation_function() is the main function used
33-
to output the evaluation response.
27+
特殊な記号/演算が含まれているか判定
28+
"""
29+
patterns = [
30+
r"d(\^|\*\*)?\d*(\*\*)?\w*/d\w+(\^|\*\*)?\d*(\*\*)?", # Ordinary diff (dy/dx, d^2y/dx^2)
31+
r"∂(\^|\*\*)?\d*(\*\*)?\w*/∂\w+(\^|\*\*)?\d*(\*\*)?", # Partial diff (∂y/∂x, ∂^2y/∂x^2)
32+
r"diff\(\w+, \w+\)", # diff function (diff(y, x))
33+
r"int", # integration (int_b^a f(x)dx)
34+
r"∫",
35+
]
36+
return any(re.search(p, expr) for p in patterns)
37+
38+
39+
def is_equivalent_sympy(expr1, expr2) -> bool | None:
40+
"""
41+
Return True/False if comparable with SymPy,
42+
or None if an error occurs.
43+
"""
44+
try:
45+
expr1, expr2 = expr1.replace("^", "**"), expr2.replace("^", "**")
46+
if not expr1.strip() and not expr2.strip():
47+
return True
48+
elif not expr1.strip() or not expr2.strip():
49+
return False
50+
51+
# Compare with Eq() for equations
52+
if "=" in expr1 and "=" in expr2:
53+
lhs1, rhs1 = expr1.split("=")
54+
lhs2, rhs2 = expr2.split("=")
55+
56+
# implicit multiplication handlable
57+
lhs1_parsed = parse_expr(lhs1, transformations=transformations)
58+
rhs1_parsed = parse_expr(rhs1, transformations=transformations)
59+
lhs2_parsed = parse_expr(lhs2, transformations=transformations)
60+
rhs2_parsed = parse_expr(rhs2, transformations=transformations)
61+
62+
eq1 = Eq(lhs1_parsed - rhs1_parsed, 0)
63+
eq2 = Eq(lhs2_parsed - rhs2_parsed, 0)
64+
65+
all_symbols = eq1.free_symbols.union(eq2.free_symbols)
66+
67+
sol1 = solve(eq1, list(all_symbols))
68+
sol2 = solve(eq2, list(all_symbols))
69+
70+
return set(sol1) == set(sol2)
71+
else:
72+
expr1_parsed = parse_expr(expr1, transformations=transformations)
73+
expr2_parsed = parse_expr(expr2, transformations=transformations)
74+
return simplify(expr1_parsed - expr2_parsed) == 0
75+
76+
except Exception as e:
77+
print(f" SymPy error: {e}")
78+
return None
79+
80+
81+
def evaluation_function(response, answer, params):
82+
load_dotenv()
83+
llm = ChatOpenAI(
84+
model=os.environ['OPENAI_MODEL'],
85+
api_key=os.environ["OPENAI_API_KEY"],
86+
)
87+
88+
# Check if LLM priority is needed
89+
needs_llm_priority = contains_special_math(response) or contains_special_math(answer)
90+
91+
# Check with SymPy first if not using LLM priority
92+
sympy_result = None
93+
if not needs_llm_priority:
94+
sympy_result = is_equivalent_sympy(response, answer)
95+
96+
prompt = fr"""
97+
Follow these steps carefully:
98+
A student response and an answer are provided below. Compare the two if they are mathematically equivalent.
99+
Only return True if they are **exactly equivalent** for all possible values of all variables.
100+
Do not assume expressions are equivalent based on similarity.
101+
There are a few types of symbols for differentiation and the following in the same square brackets are considered equivalent:
102+
[dy/dx, d/dx(y), diff(y,x)], [d^2y/dx^2, d**2y/dx**2, diff(y,x,x)], [∂y/∂x, ∂/∂x(y), diff(y,x), partial(y)/partial(x)], [∂^2y/∂x^2, ∂**2y/∂x**2, diff(y,x,x), partial**2(y)/partial(x)**2, partial^2(y)/partial(x)^2]
103+
The terms above that are not in the same square brackets are not considered equivalent.
104+
Student response: {response}
105+
Answer: {answer}
106+
107+
Return either True or False as a single word and nothing else.
34108
"""
109+
llm_response = llm.invoke(prompt)
110+
llm_result_text = llm_response.content.strip().lower()
111+
112+
if llm_result_text == "true":
113+
llm_result = True
114+
elif llm_result_text == "false":
115+
llm_result = False
116+
else:
117+
# Any weird responses
118+
llm_result = False
35119

36-
return Result(is_correct=True)
120+
if sympy_result is not None:
121+
if sympy_result == llm_result:
122+
return {
123+
"is_correct": sympy_result,
124+
"sympy_result": sympy_result,
125+
"llm_result": llm_result,
126+
"mismatch_info": ""
127+
}
128+
else:
129+
mismatch_info = (
130+
f"Mismatch detected:\n"
131+
f"- SymPy result: {sympy_result}\n"
132+
f"- LLM result: {llm_result}\n"
133+
f"Used LLM result due to mismatch"
134+
)
135+
return {
136+
"is_correct": sympy_result,
137+
"sympy_result": sympy_result,
138+
"llm_result": llm_result,
139+
"mismatch_info": mismatch_info
140+
}
141+
else:
142+
return {
143+
"is_correct": llm_result,
144+
"sympy_result": None,
145+
"llm_result": llm_result,
146+
"mismatch_info": "Used LLM result only"
147+
}

app/evaluation_test_cases.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from evaluation import Params
2+
# [response, answer, params, expected]
3+
test_cases = [
4+
["2+2", "4", Params(), True],
5+
["sin(x)**2 + cos(x)**2", "1", Params(), True],
6+
["x+y", "y+x", Params(), True],
7+
["x*y", "x+y", Params(), False],
8+
["x**2 + 2*x + 1", "(x+1)**2", Params(), True],
9+
["x**2 - 1", "(x-1)*(x+1)", Params(), True],
10+
["x^5-1", "(x-1)*(x**4+x**3+x**2+x+1)", Params(), True],
11+
["sin(x) + cos(x)", "cos(x) + sin(x)", Params(), True],
12+
["sin(x) * cos(x)", "sin(x) + cos(x)", Params(), False],
13+
["exp(x) * exp(y)", "exp(x+y)", Params(), True],
14+
["log(x*y)", "log(x) + log(y)", Params(), False],
15+
["x**3 + x**2", "x**2 * (x + 1)", Params(), True],
16+
["", "", Params(), True],
17+
["", "x", Params(), False],
18+
["x+1=0", "-2x-2=0", Params(), True],
19+
["dy/dx", "diff(y, x)", Params(), True],
20+
["(x+y)/x", "1 + y/x", Params(), True],
21+
["∂y/∂x", "diff(y, x)", Params(), True],
22+
["∫f(x)dx", "int(f(x), x)", Params(), True],
23+
["∂^2y/∂x^2", "diff(diff(y, x), x)", Params(), True],
24+
["dy/dx + 1", "diff(y, x) + 1", Params(), True],
25+
["∂y/∂x + 1", "diff(y, x) + 1", Params(), True],
26+
]

app/evaluation_tests.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,37 @@
44
from .evaluation import Params, evaluation_function
55
except ImportError:
66
from evaluation import Params, evaluation_function
7+
from evaluation_test_cases import test_cases
78

89

910
class TestEvaluationFunction(unittest.TestCase):
1011
"""
1112
TestCase Class used to test the algorithm.
12-
---
13-
Tests are used here to check that the algorithm written
14-
is working as it should.
15-
16-
It's best practise to write these tests first to get a
17-
kind of 'specification' for how your algorithm should
18-
work, and you should run these tests before committing
19-
your code to AWS.
20-
21-
Read the docs on how to use unittest here:
22-
https://docs.python.org/3/library/unittest.html
23-
24-
Use evaluation_function() to check your algorithm works
25-
as it should.
2613
"""
2714

28-
def test_returns_is_correct_true(self):
29-
response, answer, params = None, None, Params()
30-
result = evaluation_function(response, answer, params)
31-
32-
self.assertEqual(result.get("is_correct"), True)
15+
def test_multiple_cases(self):
16+
passed = 0
17+
failed = 0
18+
19+
for i, (response, answer, params, expected) in enumerate(test_cases, 1):
20+
with self.subTest(test_case=i):
21+
result = evaluation_function(response, answer, params)
22+
is_correct = result.get("is_correct")
23+
24+
try:
25+
self.assertEqual(is_correct, expected)
26+
print(f"Test {i} Passed")
27+
passed += 1
28+
except AssertionError:
29+
print(f"Test {i} Failed: expected {expected}, got {is_correct}")
30+
failed += 1
31+
32+
# mismatch_info があれば表示
33+
mismatch_info = result.get("mismatch_info")
34+
if mismatch_info:
35+
print(f"Mismatch Info (Test {i}):\n{mismatch_info}")
36+
37+
print(f"\n--- Summary ---\nPassed: {passed}, Failed: {failed}, Total: {passed + failed}")
3338

3439

3540
if __name__ == "__main__":

app/requirements.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
os
2+
typing
3+
sympy
4+
re
5+
python-dotenv
6+
langchain-openai

0 commit comments

Comments
 (0)