Skip to content

Commit 57d3e4e

Browse files
first addition of parameters
1 parent e3486da commit 57d3e4e

File tree

7 files changed

+243
-95
lines changed

7 files changed

+243
-95
lines changed

app/eval_single_testing.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

app/evaluation.py

Lines changed: 136 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from langchain_openai import ChatOpenAI
44

55
from typing import Any, TypedDict
6-
from sympy import solve, Eq, simplify
6+
from sympy import solve, Eq, simplify, Symbol
77
from sympy.parsing.sympy_parser import parse_expr, standard_transformations, implicit_multiplication_application
88
import re
9-
9+
from parameter import create_sympy_parsing_params
1010

1111
class Params(TypedDict):
1212
pass
@@ -21,43 +21,92 @@ class Result(TypedDict):
2121

2222
transformations = standard_transformations + (implicit_multiplication_application,)
2323

24+
def has_unbalanced_parentheses(expr: str) -> bool:
25+
"""
26+
Check if the expression has unbalanced parentheses
27+
"""
28+
return expr.count("(") != expr.count(")")
2429

2530
def contains_special_math(expr: str) -> bool:
2631
"""
27-
特殊な記号/演算が含まれているか判定
32+
Check if the expression contains special mathematical symbols or operations
2833
"""
34+
2935
patterns = [
30-
r"d(\^|\*\*)?\d*(\*\*)?\w*/d\w+(\^|\*\*)?\d*(\*\*)?", # Ordinary diff (dy/dx, d^2y/dx^2)
31-
r"∂(\^|\*\*)?\d*(\*\*)?\w*/∂\w+(\^|\*\*)?\d*(\*\*)?", # Partial diff (∂y/∂x, ∂^2y/∂x^2)
32-
r"diff\(\w+, \w+\)", # diff function (diff(y, x))
33-
r"int", # integration (int_b^a f(x)dx)
34-
r"∫",
36+
# Differentiation
37+
r"d(\*\*)?\d*\w*/d\w+(\*\*)?\d*", # dy/dx, d**2y/dx**2
38+
r"d/d\w+\(.*\)", # d/dx(y)
39+
r"d(\*\*)?\d*/d\w+(\*\*)?\d*\([^\)]+\)", # d**2/dx**2(y)
40+
r"D(\*\*)?\d*\w*/D\w+(\*\*)?\d*", # Dy/Dx, D**2y/Dx**2
41+
r"D/D\w+\(.*\)", # D/Dx(y)
42+
r"∂(\*\*)?\d*\w*/∂\w+(\*\*)?\d*", # ∂y/∂x
43+
r"∂/∂\w+\(.*\)", # ∂/∂x(y)
44+
r"diff\([^\)]+\)", # diff(y, x), diff(y,x,x)
45+
# Integration
46+
r"int\([^\)]+\)", # int(f(x), x)
47+
r"∫", r"∮", # ∫f(x)dx, ∮f(x)dx
48+
# Summation and delta functions
49+
r"Σ", r"∑", # summation symbols
50+
r"Π", r"∏", # product symbols
51+
r"DiracDelta", #delta functions
52+
# Infinity variations
53+
r"Infinity", r"infinity", r"∞", r"oo", r"Inf", r"inf", r"Infty", r"infty"
3554
]
3655
return any(re.search(p, expr) for p in patterns)
3756

38-
39-
def is_equivalent_sympy(expr1, expr2) -> bool | None:
57+
def replace_greek_symbols(expr: str) -> str:
58+
greek_map = {
59+
# 小文字
60+
"alpha": "α", "beta": "β", "gamma": "γ", "delta": "δ",
61+
"epsilon": "ε", "zeta": "ζ", "eta": "η", "theta": "θ",
62+
"iota": "ι", "kappa": "κ", "lambda": "λ", "mu": "μ",
63+
"nu": "ν", "xi": "ξ", "omicron": "ο", "pi": "π",
64+
"rho": "ρ", "sigma": "σ", "tau": "τ", "upsilon": "υ",
65+
"phi": "φ", "chi": "χ", "psi": "ψ", "omega": "ω",
66+
# 大文字
67+
"Alpha": "Α", "Beta": "Β", "Gamma": "Γ", "Delta": "Δ",
68+
"Epsilon": "Ε", "Zeta": "Ζ", "Eta": "Η", "Theta": "Θ",
69+
"Iota": "Ι", "Kappa": "Κ", "Lambda": "Λ", "Mu": "Μ",
70+
"Nu": "Ν", "Xi": "Ξ", "Omicron": "Ο", "Pi": "Π",
71+
"Rho": "Ρ", "Sigma": "Σ", "Tau": "Τ", "Upsilon": "Υ",
72+
"Phi": "Φ", "Chi": "Χ", "Psi": "Ψ", "Omega": "Ω"
73+
}
74+
75+
for ascii_name, greek_letter in greek_map.items():
76+
expr = re.sub(rf"\b{ascii_name}\b", greek_letter, expr)
77+
return expr
78+
79+
def is_equivalent_sympy(expr1, expr2, params) -> bool | None:
4080
"""
4181
Return True/False if comparable with SymPy,
4282
or None if an error occurs.
4383
"""
84+
if not expr1.strip() and not expr2.strip():
85+
return True
86+
if not expr1.strip() or not expr2.strip():
87+
return False
88+
4489
try:
45-
expr1, expr2 = expr1.replace("^", "**"), expr2.replace("^", "**")
46-
if not expr1.strip() and not expr2.strip():
47-
return True
48-
elif not expr1.strip() or not expr2.strip():
49-
return False
90+
# Create parsing parameters (expressions渡す版)
91+
parsing_params = create_sympy_parsing_params(params, expr1, expr2)
92+
raw_dict = parsing_params["symbol_dict"]
93+
transformations = parsing_params.get("extra_transformations", ())
94+
95+
# assumptions0 の辞書なら Symbol を作り直す
96+
local_dict = {
97+
name: Symbol(name, **attrs) if isinstance(attrs, dict) else attrs
98+
for name, attrs in raw_dict.items()
99+
}
50100

51101
# Compare with Eq() for equations
52102
if "=" in expr1 and "=" in expr2:
53103
lhs1, rhs1 = expr1.split("=")
54104
lhs2, rhs2 = expr2.split("=")
55105

56-
# implicit multiplication handlable
57-
lhs1_parsed = parse_expr(lhs1, transformations=transformations)
58-
rhs1_parsed = parse_expr(rhs1, transformations=transformations)
59-
lhs2_parsed = parse_expr(lhs2, transformations=transformations)
60-
rhs2_parsed = parse_expr(rhs2, transformations=transformations)
106+
lhs1_parsed = parse_expr(lhs1, transformations=transformations, local_dict=local_dict)
107+
rhs1_parsed = parse_expr(rhs1, transformations=transformations, local_dict=local_dict)
108+
lhs2_parsed = parse_expr(lhs2, transformations=transformations, local_dict=local_dict)
109+
rhs2_parsed = parse_expr(rhs2, transformations=transformations, local_dict=local_dict)
61110

62111
eq1 = Eq(lhs1_parsed - rhs1_parsed, 0)
63112
eq2 = Eq(lhs2_parsed - rhs2_parsed, 0)
@@ -69,79 +118,86 @@ def is_equivalent_sympy(expr1, expr2) -> bool | None:
69118

70119
return set(sol1) == set(sol2)
71120
else:
72-
expr1_parsed = parse_expr(expr1, transformations=transformations)
73-
expr2_parsed = parse_expr(expr2, transformations=transformations)
121+
expr1_parsed = parse_expr(expr1, transformations=transformations, local_dict=local_dict)
122+
expr2_parsed = parse_expr(expr2, transformations=transformations, local_dict=local_dict)
74123
return simplify(expr1_parsed - expr2_parsed) == 0
75124

76125
except Exception as e:
77-
print(f" SymPy error: {e}")
126+
print(f"SymPy error: {e}")
78127
return None
79128

80-
81-
def evaluation_function(response, answer, params):
129+
def convert_to_sympy(expr: str, params: Params) -> str:
82130
load_dotenv()
83131
llm = ChatOpenAI(
84132
model=os.environ['OPENAI_MODEL'],
85133
api_key=os.environ["OPENAI_API_KEY"],
86134
)
87-
88-
# Check if LLM priority is needed
89-
needs_llm_priority = contains_special_math(response) or contains_special_math(answer)
90-
91-
# Check with SymPy first if not using LLM priority
92-
sympy_result = None
93-
if not needs_llm_priority:
94-
sympy_result = is_equivalent_sympy(response, answer)
95-
96135
prompt = fr"""
97-
Follow these steps carefully:
98-
A student response and an answer are provided below. Compare the two if they are mathematically equivalent.
99-
Only return True if they are **exactly equivalent** for all possible values of all variables.
100-
Do not assume expressions are equivalent based on similarity.
101-
There are a few types of symbols for differentiation and the following in the same square brackets are considered equivalent:
102-
[dy/dx, d/dx(y), diff(y,x)], [d^2y/dx^2, d**2y/dx**2, diff(y,x,x)], [∂y/∂x, ∂/∂x(y), diff(y,x), partial(y)/partial(x)], [∂^2y/∂x^2, ∂**2y/∂x**2, diff(y,x,x), partial**2(y)/partial(x)**2, partial^2(y)/partial(x)^2]
103-
The terms above that are not in the same square brackets are not considered equivalent.
104-
Student response: {response}
105-
Answer: {answer}
106-
107-
Return either True or False as a single word and nothing else.
136+
Follow these steps carefully:
137+
A student response and an answer are provided below. Convert the student response into a SymPy expression.
138+
139+
When the following notations in (a) and (b) are used, they must be replaced with the equivalent SymPy expressions.
140+
All the notations in the same square brackets are equivalent, and must be replaced with the notation after the right arrow (->) after the square brackets.
141+
142+
(a) The following notations for derivatives, partial derivatives, and integrals **must be considered strictly equivalent** within the same group:
143+
- [dy/dx, d/dx(y), diff(y,x)] -> diff(y,x)
144+
- [d^2y/dx^2, d**2y/dx**2, diff(y,x,x)] -> diff(y,x,x)
145+
- [d^3y/dx^3, d**3y/dx**3, diff(y,x,x,x)] -> diff(y,x,x,x)
146+
- [Dy/Dx, D/Dx(y)] -> diff(y,t)+v.dot(gradient(y))
147+
- [∂y/∂x, ∂/∂x(y), diff(y,x), partial(y)/partial(x)] -> diff(y,x)
148+
- [∂^2y/∂x^2, ∂**2y/dx**2, diff(y,x,x), partial**2(y)/partial(x)**2, partial^2(y)/partial(x)^2] -> diff(y,x,x)
149+
- [∫f(x)dx, int(f(x),x), integrate(f(x),x), Integral(f(x),x)] -> integrate(f(x), x)
150+
- [∮f(x)dx, int(f(x),x,circular=True), integrate(f(x),x,circular=True), Integral(f(x),x,circular=True)] -> integrate(f(x), x)
151+
- [∫ₐᵇf(x)dx, ∫_a^bf(x)dx, int_a^bf(x)dx, int(f(x),(x,a,b)), integrate(f(x),(x,a,b)), Integral(f(x),(x,a,b))] -> integrate(f(x), (x, a, b))
152+
- [∫∫f(x,y)dxdy, int(int(f(x,y),x),y), integrate(f(x,y),x,y), Integral(f(x,y),x,y)] -> integrate(integrate(f(x,y),x),y)
153+
- [∇f, gradient(f), grad(f)] -> gradient(f)
154+
- [∇·F, div(F), divergence(F)] -> div(f)
155+
- [∇×F, curl(F), rot(F)] -> curl(f)
156+
157+
(b) Other notations that **must be considered equivalent** within the same group:
158+
- [Infinity, infinity, ∞, oo, Inf, inf, Infty, infty] -> oo
159+
- [a·b, a⋅b, a.b, dot(a, b), a.dot(b)] -> a.dot(b)
160+
*Note: a.b is only equivalent to these if a and b are variables, not constants like 0, 1, π, etc.*
161+
- [a×b, cross(a, b), a.cross(b)] -> a.cross(b)
162+
- [\vec{{a}}, vector(a), a.vector(), Matrix(a)] -> Matrix(a)
163+
- [â, \hat{{a}}, unit(a), normalize(a), a_hat] -> a/Abs(a)
164+
- [exp(x), e**x, e**x, exponential(x)] -> exp(x)
165+
166+
When comparing integrals, assume that any derivative or expression between the integral sign and the differential (e.g., ∂y/∂x in ∫_a ∂y/∂x dx) is the complete integrand, even if parentheses around the integrand are missing.
167+
168+
**Notations from different groups or not listed above are NOT equivalent.**
169+
170+
This is the student response: {expr}
171+
Now convert it to a SymPy expression. Ouput only the SymPy expression, without any additional text or explanation.
108172
"""
109173
llm_response = llm.invoke(prompt)
110-
llm_result_text = llm_response.content.strip().lower()
174+
return llm_response
111175

112-
if llm_result_text == "true":
113-
llm_result = True
114-
elif llm_result_text == "false":
115-
llm_result = False
116-
else:
117-
# Any weird responses
118-
llm_result = False
119-
120-
if sympy_result is not None:
121-
if sympy_result == llm_result:
122-
return {
123-
"is_correct": sympy_result,
124-
"sympy_result": sympy_result,
125-
"llm_result": llm_result,
126-
"mismatch_info": ""
127-
}
128-
else:
129-
mismatch_info = (
130-
f"Mismatch detected:\n"
131-
f"- SymPy result: {sympy_result}\n"
132-
f"- LLM result: {llm_result}\n"
133-
f"Used LLM result due to mismatch"
134-
)
135-
return {
136-
"is_correct": sympy_result,
137-
"sympy_result": sympy_result,
138-
"llm_result": llm_result,
139-
"mismatch_info": mismatch_info
140-
}
141-
else:
176+
def evaluation_function(response, answer, params):
177+
178+
if has_unbalanced_parentheses(response) or has_unbalanced_parentheses(answer):
142179
return {
143-
"is_correct": llm_result,
180+
"is_correct": False,
144181
"sympy_result": None,
145-
"llm_result": llm_result,
146-
"mismatch_info": "Used LLM result only"
147-
}
182+
"llm_result": False,
183+
"mismatch_info": "Invalid syntax: unbalanced parentheses"
184+
}
185+
response = response.replace("^", "**")
186+
answer = answer.replace("^", "**")
187+
response = response.replace(" ", "")
188+
answer = answer.replace(" ", "")
189+
response = replace_greek_symbols(response)
190+
answer = replace_greek_symbols(answer)
191+
192+
if response.strip() == "" or answer.strip() == "":
193+
needs_conversion = False
194+
else:
195+
needs_conversion = contains_special_math(response) or contains_special_math(answer)
196+
197+
if needs_conversion:
198+
response = convert_to_sympy(response, params).content.strip()
199+
answer = convert_to_sympy(answer, params).content.strip()
200+
result = None
201+
result = is_equivalent_sympy(response, answer, params)
202+
203+
return {"is_correct": result}

app/evaluation_test_cases.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from evaluation import Params
1+
from evaluation import Params, evaluation_function
2+
from parameter import create_sympy_parsing_params
23
# [response, answer, params, expected]
34
test_cases = [
4-
["2+2", "4", Params(), True],
5+
["2+2", "4", Params(), True], #1
56
["sin(x)**2 + cos(x)**2", "1", Params(), True],
67
["x+y", "y+x", Params(), True],
78
["x*y", "x+y", Params(), False],
@@ -15,6 +16,7 @@
1516
["x**3 + x**2", "x**2 * (x + 1)", Params(), True],
1617
["", "", Params(), True],
1718
["", "x", Params(), False],
19+
["1+", "1", Params(), False],
1820
["x+1=0", "-2x-2=0", Params(), True],
1921
["dy/dx", "diff(y, x)", Params(), True],
2022
["(x+y)/x", "1 + y/x", Params(), True],
@@ -23,4 +25,30 @@
2325
["∂^2y/∂x^2", "diff(diff(y, x), x)", Params(), True],
2426
["dy/dx + 1", "diff(y, x) + 1", Params(), True],
2527
["∂y/∂x + 1", "diff(y, x) + 1", Params(), True],
28+
["dp/dt", "diff(p, t)", Params(), True],
29+
["dg/dm", "diff(y,x)", Params(), False],
30+
]
31+
test_cases2 = [
32+
["infty", "Infinity", Params(), True], #1
33+
["sqrt(-1)", "I", Params(), True],
34+
["sqrt(x**2)", "x", Params(), False],
35+
["1/(x-1)", "1/(1-x)", Params(), False],
36+
["x^2", "x**2", Params(), True],
37+
["x^^2", "x**2", Params(), False],
38+
["d^3y/dx^3", "diff(y, x, x, x)", Params(), True],
39+
["∫∫f(x)dxdy", "int(int(f(x), x), y)", Params(), True],
40+
["f(x)=x+1", "f(x)-x-1=0", Params(), True],
41+
["f(x) = x**2", "f(y) = y**2", Params(), False],#should this always be false?
42+
["diff(y,x)+", "diff(y,x)+0", Params(), False],
43+
["d/dx(y", "diff(y, x)", Params(), False],
44+
["DiracDelta(x)", "0", Params(), False],
45+
["∫_{V_sys} ∂ρ/∂t dV", "int(partial(ρ)/partial(t), (V, V_sys))", Params(), True],
46+
["∮_{A_sys} ρu·n̂ dA", "Integral(rho * u.dot(n), (A, A_sys), circular=True)", Params(), True],
47+
["rho", "ρ", Params(), True],
48+
["Dx/Dt=-div(u)", "Dx/Dt+div(u)=0", Params(), True],
49+
["(1/rho)*Drho/Dt=-div(u)", "(1/ρ)*Dρ/Dt+div(u)=0", Params(), True],
50+
]
51+
test_cases3 = [
52+
["abs(x)", "sqrt(x**2)", Params(), False],
53+
["abs(x)", "sqrt(x**2)", Params(symbol_assumptions={"x": {"real": True},}), True],
2654
]

app/evaluation_tests.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .evaluation import Params, evaluation_function
55
except ImportError:
66
from evaluation import Params, evaluation_function
7-
from evaluation_test_cases import test_cases
7+
from evaluation_test_cases import test_cases, test_cases2,test_cases3
88

99

1010
class TestEvaluationFunction(unittest.TestCase):
@@ -15,8 +15,8 @@ class TestEvaluationFunction(unittest.TestCase):
1515
def test_multiple_cases(self):
1616
passed = 0
1717
failed = 0
18-
19-
for i, (response, answer, params, expected) in enumerate(test_cases, 1):
18+
case = [test_cases, test_cases2, test_cases3]
19+
for i, (response, answer, params, expected) in enumerate(case[2], 1): #change here test_cases <-> test_cases2
2020
with self.subTest(test_case=i):
2121
result = evaluation_function(response, answer, params)
2222
is_correct = result.get("is_correct")
@@ -26,16 +26,15 @@ def test_multiple_cases(self):
2626
print(f"Test {i} Passed")
2727
passed += 1
2828
except AssertionError:
29-
print(f"Test {i} Failed: expected {expected}, got {is_correct}")
29+
print(f"Test {i} Failed:")
30+
print(f" Response: {response}")
31+
print(f" Answer : {answer}")
32+
print(f" Params : {params}")
33+
print(f" Expected: {expected}, Got: {is_correct}")
3034
failed += 1
3135

32-
# mismatch_info があれば表示
33-
mismatch_info = result.get("mismatch_info")
34-
if mismatch_info:
35-
print(f"Mismatch Info (Test {i}):\n{mismatch_info}")
36-
3736
print(f"\n--- Summary ---\nPassed: {passed}, Failed: {failed}, Total: {passed + failed}")
3837

3938

4039
if __name__ == "__main__":
41-
unittest.main()
40+
unittest.main()

app/miscellaneous_testing.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from parameter import extract_variable_names, create_sympy_parsing_params
2+
from evaluation import Params, evaluation_function
3+
from sympy.parsing.sympy_parser import parse_expr
4+
from sympy import solve, Eq, simplify, Symbol
5+
6+
# parsing_params = create_sympy_parsing_params(Params(symbol_assumptions={"x": {"real": True}}), "x+2+z", "x*y*z")
7+
# local_dict = parsing_params["symbol_dict"]
8+
# print(local_dict)
9+
lhs, rhs = "2*x+1", "0"
10+
eq = Eq(parse_expr(lhs), parse_expr(rhs))
11+
print(solve(eq))
12+
lhs, rhs = "x", "-1/2"
13+
eq = Eq(parse_expr(lhs), parse_expr(rhs))
14+
print(solve(eq))

0 commit comments

Comments
 (0)