22from dotenv import load_dotenv
33from langchain_openai import ChatOpenAI
44from typing import Any , TypedDict , Union
5- from sympy import solve , Eq , simplify , Expr , symbols , Symbol , Function , FunctionClass , Integral , Derivative , Matrix , Abs , sin , cos , tan , sqrt , log , exp
5+ from sympy import solve , Eq , simplify , Expr , symbols , Symbol , Function , FunctionClass , Integral , Derivative , Matrix , Abs , sin , cos , tan , sqrt , log , exp , oo
66from sympy .core .function import AppliedUndef
77from sympy .matrices import MatrixBase
88from sympy import Basic
99from sympy .parsing .sympy_parser import parse_expr , standard_transformations , implicit_multiplication_application
1010import re
11- from parameter import create_sympy_parsing_params , Params
11+ from parameter import create_sympy_parsing_params , Params , apply_declared_functions , parse_domain , check_in_domain
1212from re_conversion import convert_diff_re , convert_integral_re , convert_other_re
1313from llm_conversion import convert_diff , convert_integral , convert_other
1414from FMX2_symbols import Fluids
15+ from evaluation_symbolise import symbolise_by_operators , is_matching_parens
16+ from sympy import sympify
1517
1618
1719class Result (TypedDict ):
@@ -20,9 +22,6 @@ class Result(TypedDict):
2022 llm_result : bool
2123 mismatch_info : str
2224
23-
24- transformations = standard_transformations + (implicit_multiplication_application ,)
25-
2625def strip_outer_parens (expr : str ) -> str :
2726 expr = expr .strip ()
2827 if not expr .startswith ("(" ) or not expr .endswith (")" ):
@@ -44,6 +43,7 @@ def has_unbalanced_parentheses(expr: str) -> bool:
4443 Check if the expression has unbalanced parentheses
4544 """
4645 return expr .count ("(" ) != expr .count (")" )
46+
4747class contains_special_math ():
4848 def contains_diff (self , expr : str ) -> bool :
4949 patterns = [
@@ -124,14 +124,14 @@ def contains_other(self, expr: str) -> bool:
124124
125125def replace_greek_symbols (expr : str ) -> str :
126126 greek_map = {
127- # 小文字
127+
128128 "α" : "alpha" , "β" : "beta" , "γ" : "gamma" , "δ" : "delta" ,
129129 "ε" : "epsilon" , "ζ" : "zeta" , "η" : "eta" , "θ" : "theta" ,
130130 "ι" : "iota" , "κ" : "kappa" , "λ" : "lambda" , "μ" : "mu" ,
131131 "ν" : "nu" , "ξ" : "xi" , "ο" : "omicron" , "π" : "pi" ,
132132 "ρ" : "rho" , "σ" : "sigma" , "τ" : "tau" , "υ" : "upsilon" ,
133133 "φ" : "phi" , "χ" : "chi" , "ψ" : "psi" , "ω" : "omega" ,
134- # 大文字
134+
135135 "Α" : "Alpha" , "Β" : "Beta" , "Γ" : "Gamma" , "Δ" : "Delta" ,
136136 "Ε" : "Epsilon" , "Ζ" : "Zeta" , "Η" : "Eta" , "Θ" : "Theta" ,
137137 "Ι" : "Iota" , "Κ" : "Kappa" , "Λ" : "Lambda" , "Μ" : "Mu" ,
@@ -149,8 +149,13 @@ def extract_symbols(expr: str) -> dict:
149149
150150 # high_order_pattern = r"\b(?:d|del)\*\*\d+[a-zA-Z_]\w*/(?:d|del)[a-zA-Z_]\w*\*\*\d+\b"
151151 # first_order_pattern = r"\b(?:d|del)[a-zA-Z_]\w*/(?:d|del)[a-zA-Z_]\w*\b"
152- material_pattern = r"\bD_[a-zA-Z_]\w*_[a-zA-Z_]\w*\b"
153-
152+ material_pattern = (
153+ r"\bD_[A-Za-z_]\w*_[A-Za-z_]\w*\b" # D_var1_var2 形式
154+ r"|" # または
155+ r"\bD[A-Za-z_]\w*/D[A-Za-z_]\w*\b" # Dy/Dx 形式
156+ r"|" # または
157+ r"\bD/D[A-Za-z_]\w*(?:\(\s*[A-Za-z_]\w*\s*\)|\s+[A-Za-z_]\w*)" # D/Dx(y) 形式
158+ )
154159 # intg_pattern = r"(?:o?intg)\((?:[^()]+|\((?:[^()]+|\([^()]*\))*\))*\)"
155160
156161 # nabla_pattern = r"(?:\bgrad\b|\bdivg\b|\brot\b|\bdot\b|\bcross\b|\bvec\b|\bhat\b)"
@@ -194,16 +199,17 @@ def is_equivalent_sympy(expr1, expr2, params) -> Union[bool, None]:
194199 fd_dict = vars (fd )
195200 valid_types = (Basic , MatrixBase )
196201 fd_func_dict = {
197- k . replace ( '_func' , '' ) : v
202+ k [: - 5 ] : v
198203 for k , v in fd_dict .items ()
199204 if k .endswith ('_func' ) and isinstance (v , FunctionClass )
200205 }
201-
206+ func_names = set ( fd_func_dict . keys ())
202207 fd_filtered = {
203208 k : v for k , v in fd_dict .items ()
204209 if isinstance (v , valid_types )
205210 and not k .endswith ('_func' )
206211 and not isinstance (v , AppliedUndef )
212+ and k not in func_names
207213 }
208214
209215 # Build local_dict for parser
@@ -212,18 +218,28 @@ def is_equivalent_sympy(expr1, expr2, params) -> Union[bool, None]:
212218 "Divergence" : fd .Divergence ,
213219 "Curl" : fd .Curl ,
214220 "smart_derivative" : fd .smart_derivative ,
221+ "smart_dot" : fd .smart_dot ,
215222 ** fd_func_dict ,
216223 ** fd_filtered ,
217224 # **symbols1,
218225 # **symbols2,
219226 }
220227
221228 for name , sym in raw_dict .items ():
222- if not isinstance (sym , dict ):
223- if params :
224- local_dict [name ] = sym
225- elif name not in local_dict :
226- local_dict [name ] = sym
229+ if isinstance (sym , dict ):
230+ local_dict [name ] = Symbol (name , ** sym )
231+ else :
232+ if params :
233+ if isinstance (params , dict ) and params :
234+ local_dict [name ] = sym
235+ elif name not in local_dict :
236+ local_dict [name ] = sym
237+
238+ if isinstance (params , dict ) and params .get ("function" ):
239+ if isinstance (expr1 , str ):
240+ expr1 = apply_declared_functions (expr1 , params ["function" ])
241+ if isinstance (expr2 , str ):
242+ expr2 = apply_declared_functions (expr2 , params ["function" ])
227243
228244 def ensure_expr (expr ):
229245 if isinstance (expr , str ):
@@ -245,7 +261,6 @@ def ensure_expr(expr):
245261 args = func_args_map [name ]
246262 applied_funcs [name ] = func (* args )
247263
248- # 置換:既に呼び出されていない関数だけ変換
249264 for name , applied in applied_funcs .items ():
250265 pattern = rf'(?<!\w){ name } (?!\w|\s*\()'
251266 expr = re .sub (pattern , f'({ str (applied )} )' , expr )
@@ -256,6 +271,23 @@ def ensure_expr(expr):
256271
257272 # Handle equations
258273 if "=" in str (expr1 ) and "=" in str (expr2 ):
274+ def unwrap_parens_if_equal (s : str ) -> str :
275+ s = s .strip ()
276+ if "=" not in s :
277+ if s .startswith ("(" ) and s .endswith (")" ) and is_matching_parens (s ):
278+ inner = s [1 :- 1 ].strip ()
279+ return inner
280+ else :
281+ left , right = s .split ("=" )
282+ if left .startswith ("(" ) and left .endswith (")" ) and is_matching_parens (left ):
283+ left = left [1 :- 1 ].strip ()
284+ if right .startswith ("(" ) and right .endswith (")" ) and is_matching_parens (right ):
285+ right = right [1 :- 1 ].strip ()
286+ return left + "=" + right
287+ return s
288+ if isinstance (expr1 , str ) and isinstance (expr2 , str ):
289+ expr1 = unwrap_parens_if_equal (expr1 )
290+ expr2 = unwrap_parens_if_equal (expr2 )
259291 lhs1 , rhs1 = str (expr1 ).split ("=" )
260292 lhs2 , rhs2 = str (expr2 ).split ("=" )
261293
@@ -276,6 +308,9 @@ def ensure_expr(expr):
276308 # Handle expression comparison
277309 expr1_parsed = ensure_expr (expr1 )
278310 expr2_parsed = ensure_expr (expr2 )
311+ expr1 = apply_declared_functions (expr1 , params .get ("function" , []))
312+ expr2 = apply_declared_functions (expr2 , params .get ("function" , []))
313+
279314
280315 if isinstance (expr1_parsed , MatrixBase ) and isinstance (expr2_parsed , MatrixBase ):
281316 return simplify (expr1_parsed - expr2_parsed ) == Matrix .zeros (* expr1_parsed .shape )
@@ -284,8 +319,30 @@ def ensure_expr(expr):
284319
285320
286321 except Exception as e :
287- print (f"SymPy error: { e } " )
288- return False
322+ if "could not solve" in str (e ).lower ():
323+ try :
324+ expr1_symbolised , expr2_symbolised , symbol_map = symbolise_by_operators (expr1 , expr2 )
325+ if "=" in str (expr1 ) and "=" in str (expr2 ):
326+ lhs1 , rhs1 = str (expr1_symbolised ).split ("=" )
327+ lhs2 , rhs2 = str (expr2_symbolised ).split ("=" )
328+ lhs1_expr = parse_expr (lhs1 .strip ())
329+ rhs1_expr = parse_expr (rhs1 .strip ())
330+ lhs2_expr = parse_expr (lhs2 .strip ())
331+ rhs2_expr = parse_expr (rhs2 .strip ())
332+ eq1 = Eq (lhs1_expr - rhs1_expr , 0 )
333+ eq2 = Eq (lhs2_expr - rhs2_expr , 0 )
334+ all_symbols = eq1 .free_symbols .union (eq2 .free_symbols )
335+ sol1 = solve (eq1 , list (all_symbols ))
336+ sol2 = solve (eq2 , list (all_symbols ))
337+ return set (sol1 ) == set (sol2 )
338+ else :
339+ return simplify (expr1_symbolised - expr2_symbolised ) == 0
340+ except Exception as fallback_error :
341+ print (f"Fallback error: { fallback_error } " )
342+ return None
343+ else :
344+ print (f"SymPy error: { e } " )
345+ return False
289346
290347def evaluation_function (response , answer , params ):
291348 if has_unbalanced_parentheses (response ) or has_unbalanced_parentheses (answer ):
@@ -298,6 +355,8 @@ def evaluation_function(response, answer, params):
298355 }
299356 response = response .replace ("^" , "**" )
300357 answer = answer .replace ("^" , "**" )
358+ response = response .replace (" " , "" )
359+ answer = answer .replace (" " , "" )
301360 response = replace_greek_symbols (response )
302361 answer = replace_greek_symbols (answer )
303362 print (response , answer )
@@ -321,22 +380,44 @@ def evaluation_function(response, answer, params):
321380
322381 if needs_conversion :
323382 if response_has_other :
324- response = convert_other (response , params ).content .strip ()
383+ # response = convert_other(response, params).content.strip()
384+ response = convert_other_re (response , params )
325385 if response_has_diff :
326- response = convert_diff (response , params ).content .strip ()
386+ # response = convert_diff(response, params).content.strip()
387+ response = convert_diff_re (response , params )
327388 if response_has_integral :
328- response = convert_integral (response , params ).content .strip ()
389+ # response = convert_integral(response, params).content.strip()
390+ response = convert_integral_re (response , params )
329391
330392 if answer_has_other :
331- answer = convert_other (answer , params ).content .strip ()
393+ # answer = convert_other(answer, params).content.strip()
394+ answer = convert_other_re (answer , params )
332395 if answer_has_diff :
333- answer = convert_diff (answer , params ).content .strip ()
396+ # answer = convert_diff(answer, params).content.strip()
397+ answer = convert_diff_re (answer , params )
334398 if answer_has_integral :
335- answer = convert_integral (answer , params ).content .strip ()
399+ # answer = convert_integral(answer, params).content.strip()
400+ answer = convert_integral_re (answer , params )
336401
337- print (response , answer ) #parentheses not removed but will be removed later
402+ print ("response converted: " + response )
403+ print ("answer converted: " + answer )
338404 response = strip_outer_parens (response )
339405 answer = strip_outer_parens (answer )
340406 result = is_equivalent_sympy (response , answer , params )
341407
408+ if "domain" in params and params ["domain" ]:
409+ try :
410+ domain = params ["domain" ]
411+ if isinstance (domain , str ):
412+ domain = parse_domain (domain )
413+
414+ val = sympify (response )
415+
416+ if val .is_number :
417+ if not check_in_domain (val , domain ):
418+ return {"is_correct" : False }
419+
420+ except Exception as e :
421+ print ("error in domain check:" , e )
422+
342423 return {"is_correct" : result }
0 commit comments