22from dotenv import load_dotenv
33from langchain_openai import ChatOpenAI
44from typing import Any , TypedDict , Union
5- from sympy import solve , Eq , simplify , Symbol , Function , integrate , diff
5+ from sympy import solve , Eq , simplify , Expr , symbols , Symbol , Function , FunctionClass , Integral , Derivative , Matrix , Abs , sin , cos , tan , sqrt , log , exp
6+ from sympy .core .function import AppliedUndef
7+ from sympy .matrices import MatrixBase
8+ from sympy import Basic
69from sympy .parsing .sympy_parser import parse_expr , standard_transformations , implicit_multiplication_application
710import re
811from parameter import create_sympy_parsing_params , Params
912from re_conversion import convert_diff_re , convert_integral_re , convert_other_re
13+ from llm_conversion import convert_diff , convert_integral , convert_other
14+ from FMX2_symbols import Fluids
1015
1116
1217class Result (TypedDict ):
@@ -18,6 +23,21 @@ class Result(TypedDict):
1823
1924transformations = standard_transformations + (implicit_multiplication_application ,)
2025
26+ def strip_outer_parens (expr : str ) -> str :
27+ expr = expr .strip ()
28+ if not expr .startswith ("(" ) or not expr .endswith (")" ):
29+ return expr
30+
31+ depth = 0
32+ for i , char in enumerate (expr ):
33+ if char == "(" :
34+ depth += 1
35+ elif char == ")" :
36+ depth -= 1
37+ if depth == 0 and i != len (expr ) - 1 :
38+ return expr
39+ return expr [1 :- 1 ].strip ()
40+
2141
2242def has_unbalanced_parentheses (expr : str ) -> bool :
2343 """
@@ -127,15 +147,15 @@ def replace_greek_symbols(expr: str) -> str:
127147
128148def extract_symbols (expr : str ) -> dict :
129149
130- high_order_pattern = r"\b(?:d|del)\*\*\d+[a-zA-Z_]\w*/(?:d|del)[a-zA-Z_]\w*\*\*\d+\b"
131- first_order_pattern = r"\b(?:d|del)[a-zA-Z_]\w*/(?:d|del)[a-zA-Z_]\w*\b"
150+ # high_order_pattern = r"\b(?:d|del)\*\*\d+[a-zA-Z_]\w*/(?:d|del)[a-zA-Z_]\w*\*\*\d+\b"
151+ # first_order_pattern = r"\b(?:d|del)[a-zA-Z_]\w*/(?:d|del)[a-zA-Z_]\w*\b"
132152 material_pattern = r"\bD_[a-zA-Z_]\w*_[a-zA-Z_]\w*\b"
133153
134- intg_pattern = r"(?:o?intg)\((?:[^()]+|\((?:[^()]+|\([^()]*\))*\))*\)"
154+ # intg_pattern = r"(?:o?intg)\((?:[^()]+|\((?:[^()]+|\([^()]*\))*\))*\)"
135155
136- nabla_pattern = r"(?:\bgrad\b|\bdivg\b|\brot\b|\bdot\b|\bcross\b|\bvec\b|\bhat\b)"
156+ # nabla_pattern = r"(?:\bgrad\b|\bdivg\b|\brot\b|\bdot\b|\bcross\b|\bvec\b|\bhat\b)"
137157
138- combined_pattern = f"{ high_order_pattern } | { first_order_pattern } | { material_pattern } | { intg_pattern } | { nabla_pattern } "
158+ combined_pattern = f"{ material_pattern } "
139159
140160 matches = re .findall (combined_pattern , expr )
141161
@@ -147,73 +167,127 @@ def extract_symbols(expr: str) -> dict:
147167def is_equivalent_sympy (expr1 , expr2 , params ) -> Union [bool , None ]:
148168 """
149169 Return True/False if comparable with SymPy,
150- or None if an error occurs.
170+ or False if an error occurs.
151171 """
152- if not expr1 .strip () and not expr2 .strip ():
153- return True
154- if not expr1 .strip () or not expr2 .strip ():
155- return False
172+
173+ if isinstance (expr1 , str ) and isinstance (expr2 , str ):
174+ if not expr1 .strip () and not expr2 .strip ():
175+ return True
176+ if not expr1 .strip () or not expr2 .strip ():
177+ return False
156178
157179 try :
158- parsing_params = create_sympy_parsing_params (params , expr1 , expr2 )
180+ # Always convert to string before parsing assumptions
181+ parsing_params = create_sympy_parsing_params (params , str (expr1 ), str (expr2 ))
159182 raw_dict = parsing_params ["symbol_dict" ]
160- transformations = parsing_params .get ("extra_transformations" , standard_transformations + (implicit_multiplication_application ,))
183+ transformations = parsing_params .get (
184+ "extra_transformations" ,
185+ standard_transformations + (implicit_multiplication_application ,)
186+ )
161187
162- symbols1 = extract_symbols (expr1 )
163- symbols2 = extract_symbols (expr2 )
188+ # Optional: Extract symbols from string inputs
189+ # symbols1 = extract_symbols(expr1) if isinstance(expr1, str) else expr1
190+ # symbols2 = extract_symbols(expr2) if isinstance(expr2, str) else expr2
191+
192+ # Optional fluid context
193+ fd = Fluids ()
194+ fd_dict = vars (fd )
195+ valid_types = (Basic , MatrixBase )
196+ fd_func_dict = {
197+ k .replace ('_func' , '' ): v
198+ for k , v in fd_dict .items ()
199+ if k .endswith ('_func' ) and isinstance (v , FunctionClass )
200+ }
164201
202+ fd_filtered = {
203+ k : v for k , v in fd_dict .items ()
204+ if isinstance (v , valid_types )
205+ and not k .endswith ('_func' )
206+ and not isinstance (v , AppliedUndef )
207+ }
208+
209+ # Build local_dict for parser
165210 local_dict = {
166- ** symbols1 ,
167- ** symbols2 ,
211+ "Gradient" : fd .Gradient ,
212+ "Divergence" : fd .Divergence ,
213+ "Curl" : fd .Curl ,
214+ "smart_derivative" : fd .smart_derivative ,
215+ ** fd_func_dict ,
216+ ** fd_filtered ,
217+ # **symbols1,
218+ # **symbols2,
168219 }
169- local_dict ["intg" ] = Function ("intg" )
170- local_dict ["ointg" ] = Function ("ointg" )
171- local_dict ["grad" ] = Function ("grad" )
172- local_dict ["divg" ] = Function ("divg" )
173- local_dict ["rot" ] = Function ("rot" )
174- local_dict ["dot" ] = Function ("dot" )
175- local_dict ["cross" ] = Function ("cross" )
176- local_dict ["vec" ] = Function ("vec" )
177- local_dict ["hat" ] = Function ("hat" )
178-
179- for name , attrs in raw_dict .items ():
180- if name == "integrate" :
181- local_dict [name ] = Function (name )
182- elif name not in local_dict :
183- local_dict [name ] = Symbol (name , ** attrs ) if isinstance (attrs , dict ) else attrs
184-
185- if "=" in expr1 and "=" in expr2 :
186- lhs1 , rhs1 = expr1 .split ("=" )
187- lhs2 , rhs2 = expr2 .split ("=" )
188- lhs1_parsed = parse_expr (lhs1 , transformations = transformations , local_dict = local_dict )
189- rhs1_parsed = parse_expr (rhs1 , transformations = transformations , local_dict = local_dict )
190- lhs2_parsed = parse_expr (lhs2 , transformations = transformations , local_dict = local_dict )
191- rhs2_parsed = parse_expr (rhs2 , transformations = transformations , local_dict = local_dict )
220+
221+ for name , sym in raw_dict .items ():
222+ if not isinstance (sym , dict ):
223+ if params :
224+ local_dict [name ] = sym
225+ elif name not in local_dict :
226+ local_dict [name ] = sym
227+
228+ def ensure_expr (expr ):
229+ if isinstance (expr , str ):
230+ func_args_map = {
231+ "u" : (fd .x , fd .y , fd .z , fd .t ),
232+ "v" : (fd .x , fd .y , fd .z , fd .t ),
233+ "w" : (fd .x , fd .y , fd .z , fd .t ),
234+ "T" : (fd .x , fd .y , fd .z , fd .t ),
235+ "rho" : (fd .x , fd .y , fd .z , fd .t ),
236+ "p" : (fd .x , fd .y , fd .z , fd .t ),
237+ "u_r" : (fd .r , fd .theta , fd .z , fd .t ),
238+ "u_theta" : (fd .r , fd .theta , fd .z , fd .t ),
239+ "u_z" : (fd .r , fd .theta , fd .z , fd .t ),
240+ }
241+
242+ applied_funcs = {}
243+ for name , func in local_dict .items ():
244+ if isinstance (func , FunctionClass ) and name in func_args_map :
245+ args = func_args_map [name ]
246+ applied_funcs [name ] = func (* args )
247+
248+ # 置換:既に呼び出されていない関数だけ変換
249+ for name , applied in applied_funcs .items ():
250+ pattern = rf'(?<!\w){ name } (?!\w|\s*\()'
251+ expr = re .sub (pattern , f'({ str (applied )} )' , expr )
252+
253+ return parse_expr (expr , transformations = transformations , local_dict = local_dict )
254+ else :
255+ return expr
256+
257+ # Handle equations
258+ if "=" in str (expr1 ) and "=" in str (expr2 ):
259+ lhs1 , rhs1 = str (expr1 ).split ("=" )
260+ lhs2 , rhs2 = str (expr2 ).split ("=" )
261+
262+ lhs1_parsed = ensure_expr (lhs1 )
263+ rhs1_parsed = ensure_expr (rhs1 )
264+ lhs2_parsed = ensure_expr (lhs2 )
265+ rhs2_parsed = ensure_expr (rhs2 )
266+
192267 eq1 = Eq (lhs1_parsed - rhs1_parsed , 0 )
193268 eq2 = Eq (lhs2_parsed - rhs2_parsed , 0 )
194269
195270 all_symbols = eq1 .free_symbols .union (eq2 .free_symbols )
196-
197271 sol1 = solve (eq1 , list (all_symbols ))
198272 sol2 = solve (eq2 , list (all_symbols ))
199273
200274 return set (sol1 ) == set (sol2 )
201275
202- # Parse expressions
203- expr1_parsed = parse_expr (expr1 , transformations = transformations , local_dict = local_dict )
204- expr2_parsed = parse_expr (expr2 , transformations = transformations , local_dict = local_dict )
276+ # Handle expression comparison
277+ expr1_parsed = ensure_expr (expr1 )
278+ expr2_parsed = ensure_expr (expr2 )
205279
206- print ("expr1_parsed:" , expr1_parsed )
207- print ("expr2_parsed:" , expr2_parsed )
280+ if isinstance (expr1_parsed , MatrixBase ) and isinstance (expr2_parsed , MatrixBase ):
281+ return simplify (expr1_parsed - expr2_parsed ) == Matrix .zeros (* expr1_parsed .shape )
282+ else :
283+ return simplify (expr1_parsed - expr2_parsed ) == 0
208284
209- return simplify (expr1_parsed - expr2_parsed ) == 0
210285
211286 except Exception as e :
212287 print (f"SymPy error: { e } " )
213288 return False
214289
215290def evaluation_function (response , answer , params ):
216-
217291 if has_unbalanced_parentheses (response ) or has_unbalanced_parentheses (answer ):
218292 return {
219293 "is_correct" : False ,
@@ -247,21 +321,22 @@ def evaluation_function(response, answer, params):
247321
248322 if needs_conversion :
249323 if response_has_other :
250- response = convert_other_re (response , params )
324+ response = convert_other (response , params ). content . strip ( )
251325 if response_has_diff :
252- response = convert_diff_re (response , params )
326+ response = convert_diff (response , params ). content . strip ( )
253327 if response_has_integral :
254- response = convert_integral_re (response , params )
328+ response = convert_integral (response , params ). content . strip ( )
255329
256330 if answer_has_other :
257- answer = convert_other_re (answer , params )
331+ answer = convert_other (answer , params ). content . strip ( )
258332 if answer_has_diff :
259- answer = convert_diff_re (answer , params )
333+ answer = convert_diff (answer , params ). content . strip ( )
260334 if answer_has_integral :
261- answer = convert_integral_re (answer , params )
262-
263- print (response , answer )
335+ answer = convert_integral (answer , params ).content .strip ()
264336
337+ print (response , answer ) #parentheses not removed but will be removed later
338+ response = strip_outer_parens (response )
339+ answer = strip_outer_parens (answer )
265340 result = is_equivalent_sympy (response , answer , params )
266341
267342 return {"is_correct" : result }
0 commit comments