33from langchain_openai import ChatOpenAI
44
55from typing import Any , TypedDict
6- from sympy import solve , Eq , simplify
6+ from sympy import solve , Eq , simplify , Symbol
77from sympy .parsing .sympy_parser import parse_expr , standard_transformations , implicit_multiplication_application
88import re
9-
9+ from parameter import create_sympy_parsing_params
1010
class Params(TypedDict):
    """Typed dictionary for the evaluation parameters; it declares no keys."""
@@ -21,43 +21,92 @@ class Result(TypedDict):
2121
# Module-level parser transformations: SymPy's standard set followed by
# implicit multiplication application.
transformations = (*standard_transformations, implicit_multiplication_application)
2323
def has_unbalanced_parentheses(expr: str) -> bool:
    """
    Check if the expression has unbalanced parentheses.

    The expression is scanned left to right while tracking the nesting
    depth, so a ")" that appears before its matching "(" (for example
    ")x(") is reported as unbalanced even though the number of opening and
    closing parentheses is the same.

    Args:
        expr: Expression text to inspect.

    Returns:
        True if the parentheses are not properly paired, False otherwise
        (including for an empty string or a string without parentheses).
    """
    depth = 0
    for char in expr:
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth < 0:
                # A closer with no opener to its left cannot be matched later.
                return True
    return depth != 0
2429
# Compiled once at import time so each call only runs the searches.
_SPECIAL_MATH_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        # Differentiation
        r"d(\*\*)?\d*\w*/d\w+(\*\*)?\d*",           # dy/dx, d**2y/dx**2
        r"d/d\w+\(.*\)",                            # d/dx(y)
        r"d(\*\*)?\d*/d\w+(\*\*)?\d*\([^\)]+\)",    # d**2/dx**2(y)
        r"D(\*\*)?\d*\w*/D\w+(\*\*)?\d*",           # Dy/Dx, D**2y/Dx**2
        r"D/D\w+\(.*\)",                            # D/Dx(y)
        r"∂(\*\*)?\d*\w*/∂\w+(\*\*)?\d*",           # ∂y/∂x
        r"∂/∂\w+\(.*\)",                            # ∂/∂x(y)
        r"diff\([^\)]+\)",                          # diff(y, x), diff(y,x,x)
        # Integration
        r"int\([^\)]+\)",                           # int(f(x), x)
        r"∫", r"∮",                                 # ∫f(x)dx, ∮f(x)dx
        # Summation and product symbols
        r"Σ", r"∑",
        r"Π", r"∏",
        # Delta functions
        r"DiracDelta",
        # Infinity variations. The ASCII spellings must not touch another
        # ASCII letter, otherwise "oo"/"inf" would also fire inside ordinary
        # names such as "floor", "root" or "info".
        r"∞",
        r"(?<![A-Za-z])(?:Infinity|infinity|Infty|infty|Inf|inf|oo)(?![A-Za-z])",
    )
)


def contains_special_math(expr: str) -> bool:
    """
    Check if the expression contains special mathematical symbols or operations.

    "Special" means any of: derivative notations (dy/dx, d/dx(y), Dy/Dx,
    ∂y/∂x, diff(...)), integrals (int(...), ∫, ∮), summation/product
    symbols (Σ, ∑, Π, ∏), DiracDelta, or a spelling of infinity
    (∞, oo, inf, Inf, infty, Infty, infinity, Infinity).

    The ASCII infinity spellings only count when they are not embedded in a
    longer run of ASCII letters, so "floor(x)" is not treated as containing
    "oo". A leading digit is still allowed ("2oo").

    Args:
        expr: Expression text to inspect.

    Returns:
        True if at least one pattern is found anywhere in ``expr``.
    """
    return any(pattern.search(expr) for pattern in _SPECIAL_MATH_PATTERNS)
3756
38-
39- def is_equivalent_sympy (expr1 , expr2 ) -> bool | None :
# ASCII spelling -> Greek letter.
_GREEK_MAP = {
    # Lowercase
    "alpha": "α", "beta": "β", "gamma": "γ", "delta": "δ",
    "epsilon": "ε", "zeta": "ζ", "eta": "η", "theta": "θ",
    "iota": "ι", "kappa": "κ", "lambda": "λ", "mu": "μ",
    "nu": "ν", "xi": "ξ", "omicron": "ο", "pi": "π",
    "rho": "ρ", "sigma": "σ", "tau": "τ", "upsilon": "υ",
    "phi": "φ", "chi": "χ", "psi": "ψ", "omega": "ω",
    # Uppercase
    "Alpha": "Α", "Beta": "Β", "Gamma": "Γ", "Delta": "Δ",
    "Epsilon": "Ε", "Zeta": "Ζ", "Eta": "Η", "Theta": "Θ",
    "Iota": "Ι", "Kappa": "Κ", "Lambda": "Λ", "Mu": "Μ",
    "Nu": "Ν", "Xi": "Ξ", "Omicron": "Ο", "Pi": "Π",
    "Rho": "Ρ", "Sigma": "Σ", "Tau": "Τ", "Upsilon": "Υ",
    "Phi": "Φ", "Chi": "Χ", "Psi": "Ψ", "Omega": "Ω",
}

# One alternation over every name, compiled once. The surrounding \b keep
# the match to whole words, so "eta" does not fire inside "beta" or "theta".
_GREEK_NAME_RE = re.compile(
    r"\b(?:" + "|".join(re.escape(name) for name in _GREEK_MAP) + r")\b"
)


def replace_greek_symbols(expr: str) -> str:
    """
    Replace spelled-out Greek letter names with the Greek letters themselves.

    Only whole words are replaced (regex word boundaries on both sides):
    "alpha+beta" becomes "α+β", while "alphabet" is left untouched. Matching
    is case-sensitive: "delta" maps to "δ" and "Delta" to "Δ".

    Args:
        expr: Expression text that may contain names such as "alpha" or "Pi".

    Returns:
        ``expr`` with every whole-word Greek letter name substituted.
    """
    return _GREEK_NAME_RE.sub(lambda match: _GREEK_MAP[match.group(0)], expr)
78+
79+ def is_equivalent_sympy (expr1 , expr2 , params ) -> bool | None :
4080 """
4181 Return True/False if comparable with SymPy,
4282 or None if an error occurs.
4383 """
84+ if not expr1 .strip () and not expr2 .strip ():
85+ return True
86+ if not expr1 .strip () or not expr2 .strip ():
87+ return False
88+
4489 try :
45- expr1 , expr2 = expr1 .replace ("^" , "**" ), expr2 .replace ("^" , "**" )
46- if not expr1 .strip () and not expr2 .strip ():
47- return True
48- elif not expr1 .strip () or not expr2 .strip ():
49- return False
90+ # Create parsing parameters (expressions渡す版)
91+ parsing_params = create_sympy_parsing_params (params , expr1 , expr2 )
92+ raw_dict = parsing_params ["symbol_dict" ]
93+ transformations = parsing_params .get ("extra_transformations" , ())
94+
95+ # assumptions0 の辞書なら Symbol を作り直す
96+ local_dict = {
97+ name : Symbol (name , ** attrs ) if isinstance (attrs , dict ) else attrs
98+ for name , attrs in raw_dict .items ()
99+ }
50100
51101 # Compare with Eq() for equations
52102 if "=" in expr1 and "=" in expr2 :
53103 lhs1 , rhs1 = expr1 .split ("=" )
54104 lhs2 , rhs2 = expr2 .split ("=" )
55105
56- # implicit multiplication handlable
57- lhs1_parsed = parse_expr (lhs1 , transformations = transformations )
58- rhs1_parsed = parse_expr (rhs1 , transformations = transformations )
59- lhs2_parsed = parse_expr (lhs2 , transformations = transformations )
60- rhs2_parsed = parse_expr (rhs2 , transformations = transformations )
106+ lhs1_parsed = parse_expr (lhs1 , transformations = transformations , local_dict = local_dict )
107+ rhs1_parsed = parse_expr (rhs1 , transformations = transformations , local_dict = local_dict )
108+ lhs2_parsed = parse_expr (lhs2 , transformations = transformations , local_dict = local_dict )
109+ rhs2_parsed = parse_expr (rhs2 , transformations = transformations , local_dict = local_dict )
61110
62111 eq1 = Eq (lhs1_parsed - rhs1_parsed , 0 )
63112 eq2 = Eq (lhs2_parsed - rhs2_parsed , 0 )
@@ -69,79 +118,86 @@ def is_equivalent_sympy(expr1, expr2) -> bool | None:
69118
70119 return set (sol1 ) == set (sol2 )
71120 else :
72- expr1_parsed = parse_expr (expr1 , transformations = transformations )
73- expr2_parsed = parse_expr (expr2 , transformations = transformations )
121+ expr1_parsed = parse_expr (expr1 , transformations = transformations , local_dict = local_dict )
122+ expr2_parsed = parse_expr (expr2 , transformations = transformations , local_dict = local_dict )
74123 return simplify (expr1_parsed - expr2_parsed ) == 0
75124
76125 except Exception as e :
77- print (f" SymPy error: { e } " )
126+ print (f"SymPy error: { e } " )
78127 return None
79128
80-
81- def evaluation_function (response , answer , params ):
def convert_to_sympy(expr: str, params: Params) -> Any:
    """
    Ask the configured OpenAI chat model to rewrite ``expr`` as a SymPy expression.

    ``load_dotenv()`` is called first, then the model name and API key are
    read from the ``OPENAI_MODEL`` and ``OPENAI_API_KEY`` environment
    variables. The prompt lists groups of equivalent notations (derivatives,
    integrals, vector operators, infinity, ...) together with the SymPy
    spelling each group must be converted to.

    Args:
        expr: The expression text to convert; it is interpolated verbatim
            into the prompt.
        params: Evaluation parameters. Not used by this function.

    Returns:
        The object returned by ``llm.invoke(prompt)``, not a plain string.
        The caller in this file reads the converted text from its
        ``.content`` attribute.

    Raises:
        KeyError: If ``OPENAI_MODEL`` or ``OPENAI_API_KEY`` is not set.
    """
    load_dotenv()
    llm = ChatOpenAI(
        model=os.environ['OPENAI_MODEL'],
        api_key=os.environ["OPENAI_API_KEY"],
    )
    prompt = fr"""
    Follow these steps carefully:
    A student response and an answer are provided below. Convert the student response into a SymPy expression.

    When the following notations in (a) and (b) are used, they must be replaced with the equivalent SymPy expressions.
    All the notations in the same square brackets are equivalent, and must be replaced with the notation after the right arrow (->) after the square brackets.

    (a) The following notations for derivatives, partial derivatives, and integrals **must be considered strictly equivalent** within the same group:
    - [dy/dx, d/dx(y), diff(y,x)] -> diff(y,x)
    - [d^2y/dx^2, d**2y/dx**2, diff(y,x,x)] -> diff(y,x,x)
    - [d^3y/dx^3, d**3y/dx**3, diff(y,x,x,x)] -> diff(y,x,x,x)
    - [Dy/Dx, D/Dx(y)] -> diff(y,t)+v.dot(gradient(y))
    - [∂y/∂x, ∂/∂x(y), diff(y,x), partial(y)/partial(x)] -> diff(y,x)
    - [∂^2y/∂x^2, ∂**2y/∂x**2, diff(y,x,x), partial**2(y)/partial(x)**2, partial^2(y)/partial(x)^2] -> diff(y,x,x)
    - [∫f(x)dx, int(f(x),x), integrate(f(x),x), Integral(f(x),x)] -> integrate(f(x), x)
    - [∮f(x)dx, int(f(x),x,circular=True), integrate(f(x),x,circular=True), Integral(f(x),x,circular=True)] -> integrate(f(x), x)
    - [∫ₐᵇf(x)dx, ∫_a^bf(x)dx, int_a^bf(x)dx, int(f(x),(x,a,b)), integrate(f(x),(x,a,b)), Integral(f(x),(x,a,b))] -> integrate(f(x), (x, a, b))
    - [∫∫f(x,y)dxdy, int(int(f(x,y),x),y), integrate(f(x,y),x,y), Integral(f(x,y),x,y)] -> integrate(integrate(f(x,y),x),y)
    - [∇f, gradient(f), grad(f)] -> gradient(f)
    - [∇·F, div(F), divergence(F)] -> div(F)
    - [∇×F, curl(F), rot(F)] -> curl(F)

    (b) Other notations that **must be considered equivalent** within the same group:
    - [Infinity, infinity, ∞, oo, Inf, inf, Infty, infty] -> oo
    - [a·b, a⋅b, a.b, dot(a, b), a.dot(b)] -> a.dot(b)
    *Note: a.b is only equivalent to these if a and b are variables, not constants like 0, 1, π, etc.*
    - [a×b, cross(a, b), a.cross(b)] -> a.cross(b)
    - [\vec{{a}}, vector(a), a.vector(), Matrix(a)] -> Matrix(a)
    - [â, \hat{{a}}, unit(a), normalize(a), a_hat] -> a/Abs(a)
    - [exp(x), e**x, e**x, exponential(x)] -> exp(x)

    When comparing integrals, assume that any derivative or expression between the integral sign and the differential (e.g., ∂y/∂x in ∫_a ∂y/∂x dx) is the complete integrand, even if parentheses around the integrand are missing.

    **Notations from different groups or not listed above are NOT equivalent.**

    This is the student response: {expr}
    Now convert it to a SymPy expression. Output only the SymPy expression, without any additional text or explanation.
    """
    llm_response = llm.invoke(prompt)
    return llm_response
111175
112- if llm_result_text == "true" :
113- llm_result = True
114- elif llm_result_text == "false" :
115- llm_result = False
116- else :
117- # Any weird responses
118- llm_result = False
119-
120- if sympy_result is not None :
121- if sympy_result == llm_result :
122- return {
123- "is_correct" : sympy_result ,
124- "sympy_result" : sympy_result ,
125- "llm_result" : llm_result ,
126- "mismatch_info" : ""
127- }
128- else :
129- mismatch_info = (
130- f"Mismatch detected:\n "
131- f"- SymPy result: { sympy_result } \n "
132- f"- LLM result: { llm_result } \n "
133- f"Used LLM result due to mismatch"
134- )
135- return {
136- "is_correct" : sympy_result ,
137- "sympy_result" : sympy_result ,
138- "llm_result" : llm_result ,
139- "mismatch_info" : mismatch_info
140- }
141- else :
def evaluation_function(response, answer, params):
    """
    Decide whether a student ``response`` is mathematically equivalent to ``answer``.

    Steps:
      1. Reject the pair immediately if either string has unbalanced
         parentheses.
      2. Normalise both strings: "^" becomes "**", spaces are removed and
         spelled-out Greek letter names are replaced by Greek letters.
      3. If both strings are non-empty and either contains special
         notation (see ``contains_special_math``), both are converted to
         SymPy syntax through ``convert_to_sympy``.
      4. Compare the two strings with ``is_equivalent_sympy``.

    Args:
        response: The student's expression as text.
        answer: The reference expression as text.
        params: Evaluation parameters, forwarded to ``convert_to_sympy`` and
            ``is_equivalent_sympy``.

    Returns:
        For unbalanced parentheses, a dict with ``is_correct`` False,
        ``sympy_result`` None, ``llm_result`` False and a ``mismatch_info``
        message. Otherwise ``{"is_correct": <bool>}``; a comparison that
        ``is_equivalent_sympy`` could not carry out (it returned None) is
        reported as False.
    """
    if has_unbalanced_parentheses(response) or has_unbalanced_parentheses(answer):
        return {
            "is_correct": False,
            "sympy_result": None,
            "llm_result": False,
            "mismatch_info": "Invalid syntax: unbalanced parentheses"
        }

    # Normalise the notation before any detection or parsing takes place.
    response = response.replace("^", "**")
    answer = answer.replace("^", "**")
    response = response.replace(" ", "")
    answer = answer.replace(" ", "")
    response = replace_greek_symbols(response)
    answer = replace_greek_symbols(answer)

    # An empty side is handled by is_equivalent_sympy directly; converting it
    # would only cost a model call.
    if response.strip() == "" or answer.strip() == "":
        needs_conversion = False
    else:
        needs_conversion = contains_special_math(response) or contains_special_math(answer)

    if needs_conversion:
        response = convert_to_sympy(response, params).content.strip()
        answer = convert_to_sympy(answer, params).content.strip()

    result = is_equivalent_sympy(response, answer, params)

    # is_equivalent_sympy returns None when SymPy raised; report that as
    # "not correct" so that is_correct is always a boolean.
    return {"is_correct": bool(result)}
0 commit comments