Skip to content

Commit c4f8208

Browse files
Addition of preview and parameter extraction
1 parent fa588fc commit c4f8208

13 files changed

+971
-446
lines changed

app/FMX2_symbols.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33

44
class Fluids:
55
def __init__(self):
6-
# 座標
76
self.x, self.y, self.z, self.t = symbols('x y z t')
87
self.r, self.theta = symbols('r theta')
8+
self.inf = symbols('inf')
99

10-
# 物理量(スカラー場)
1110
self.u_func = Function('u')
1211
self.v_func = Function('v')
1312
self.w_func = Function('w')
@@ -22,20 +21,16 @@ def __init__(self):
2221
self.rho = self.rho_func(self.x, self.y, self.z, self.t)
2322
self.p = self.p_func(self.x, self.y, self.z, self.t)
2423

25-
# 定数
2624
self.a, self.b, self.c, self.h = symbols('a b c h')
2725
self.m, self.g, self.mu, self.nu, self.R, self.c_p, self.c_v, self.kappa = symbols(
2826
'm g mu nu R c_p c_v kappa', positive=True
2927
)
3028

31-
# 速度ベクトル(Matrixベース)
3229
self.u_vec = Matrix([[self.u],[self.v],[self.w]])
3330
self.x_vec = Matrix([[self.x],[self.y],[self.z]])
3431

35-
# 勾配テンソル(成分ごとの微分)
3632
self.grad_u = self.Gradient(self.u_vec)
3733

38-
# 円筒座標系の速度(Matrixベース)
3934
self.u_r_func = Function('u_r')
4035
self.u_theta_func = Function('u_theta')
4136
self.u_z_func = Function('u_z')
@@ -53,42 +48,42 @@ def __init__(self):
5348
self.theta_hat = Matrix([[0],[1],[0]])
5449

5550
def Gradient(self, f):
56-
"""スカラーかベクトルに応じて勾配 or 勾配テンソルを返す"""
5751
if isinstance(f, Expr):
5852
return Matrix([
59-
[f.diff(self.x)],
60-
[f.diff(self.y)],
61-
[f.diff(self.z)]
53+
[Derivative(f, self.x)],
54+
[Derivative(f, self.y)],
55+
[Derivative(f, self.z)]
6256
])
6357

6458
elif isinstance(f, MatrixBase) and f.shape == (3, 1):
6559
return Matrix([
66-
[f[0].diff(self.x), f[0].diff(self.y), f[0].diff(self.z)],
67-
[f[1].diff(self.x), f[1].diff(self.y), f[1].diff(self.z)],
68-
[f[2].diff(self.x), f[2].diff(self.y), f[2].diff(self.z)]
60+
[Derivative(f[0], self.x), Derivative(f[0], self.y), Derivative(f[0], self.z)],
61+
[Derivative(f[1], self.x), Derivative(f[1], self.y), Derivative(f[1], self.z)],
62+
[Derivative(f[2], self.x), Derivative(f[2], self.y), Derivative(f[2], self.z)]
6963
])
7064

7165
else:
7266
raise TypeError("Gradient() expects a scalar Expr or a 3x1 Matrix.")
7367

7468
def Divergence(self, vec):
75-
"""ベクトルの発散(スカラー)"""
7669
if not (isinstance(vec, MatrixBase) and vec.shape == (3, 1)):
7770
raise TypeError("Divergence expects a 3x1 vector Matrix.")
78-
return vec[0].diff(self.x) + vec[1].diff(self.y) + vec[2].diff(self.z)
71+
return Derivative(vec[0],(self.x)) + Derivative(vec[1],(self.y)) + Derivative(vec[2],(self.z))
7972

8073
def Curl(self, vec):
81-
"""ベクトルの回転(3x1ベクトル)"""
8274
if not (isinstance(vec, MatrixBase) and vec.shape == (3, 1)):
8375
raise TypeError("Curl expects a 3x1 vector Matrix.")
8476
return Matrix([
85-
[vec[2].diff(self.y) - vec[1].diff(self.z)],
86-
[vec[0].diff(self.z) - vec[2].diff(self.x)],
87-
[vec[1].diff(self.x) - vec[0].diff(self.y)]
77+
[Derivative(vec[2],(self.y)) - Derivative(vec[1],(self.z))],
78+
[Derivative(vec[0],(self.z)) - Derivative(vec[2],(self.x))],
79+
[Derivative(vec[1],(self.x)) - Derivative(vec[0],(self.y))]
8880
])
89-
def smart_derivative(self, expr, var):
90-
# スカラーならそのまま
81+
def smart_derivative(self, expr, var, n=1):
9182
if not isinstance(expr, MatrixBase):
92-
return Derivative(expr, var)
93-
# ベクトルなら成分ごとに
94-
return Matrix([[Derivative(expr[i], var)] for i in range(expr.rows)])
83+
return Derivative(expr, var, n)
84+
return Matrix([[Derivative(expr[i], var, n)] for i in range(expr.rows)])
85+
def smart_dot(self, expr1, expr2):
86+
if isinstance(expr1, MatrixBase) and isinstance(expr2, MatrixBase):
87+
if expr1.shape == (3, 1) and expr2.shape == (3, 1):
88+
return expr1.dot(expr2)
89+
return expr1 * expr2

app/evaluation.py

Lines changed: 107 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22
from dotenv import load_dotenv
33
from langchain_openai import ChatOpenAI
44
from typing import Any, TypedDict, Union
5-
from sympy import solve, Eq, simplify, Expr, symbols, Symbol, Function, FunctionClass, Integral, Derivative, Matrix, Abs, sin, cos, tan, sqrt, log, exp
5+
from sympy import solve, Eq, simplify, Expr, symbols, Symbol, Function, FunctionClass, Integral, Derivative, Matrix, Abs, sin, cos, tan, sqrt, log, exp, oo
66
from sympy.core.function import AppliedUndef
77
from sympy.matrices import MatrixBase
88
from sympy import Basic
99
from sympy.parsing.sympy_parser import parse_expr, standard_transformations, implicit_multiplication_application
1010
import re
11-
from parameter import create_sympy_parsing_params, Params
11+
from parameter import create_sympy_parsing_params, Params, apply_declared_functions, parse_domain, check_in_domain
1212
from re_conversion import convert_diff_re, convert_integral_re, convert_other_re
1313
from llm_conversion import convert_diff, convert_integral, convert_other
1414
from FMX2_symbols import Fluids
15+
from evaluation_symbolise import symbolise_by_operators, is_matching_parens
16+
from sympy import sympify
1517

1618

1719
class Result(TypedDict):
@@ -20,9 +22,6 @@ class Result(TypedDict):
2022
llm_result: bool
2123
mismatch_info: str
2224

23-
24-
transformations = standard_transformations + (implicit_multiplication_application,)
25-
2625
def strip_outer_parens(expr: str) -> str:
2726
expr = expr.strip()
2827
if not expr.startswith("(") or not expr.endswith(")"):
@@ -44,6 +43,7 @@ def has_unbalanced_parentheses(expr: str) -> bool:
4443
Check if the expression has unbalanced parentheses
4544
"""
4645
return expr.count("(") != expr.count(")")
46+
4747
class contains_special_math():
4848
def contains_diff(self, expr: str) -> bool:
4949
patterns = [
@@ -124,14 +124,14 @@ def contains_other(self, expr: str) -> bool:
124124

125125
def replace_greek_symbols(expr: str) -> str:
126126
greek_map = {
127-
# 小文字
127+
128128
"α": "alpha", "β": "beta", "γ": "gamma", "δ": "delta",
129129
"ε": "epsilon", "ζ": "zeta", "η": "eta", "θ": "theta",
130130
"ι": "iota", "κ": "kappa", "λ": "lambda", "μ": "mu",
131131
"ν": "nu", "ξ": "xi", "ο": "omicron", "π": "pi",
132132
"ρ": "rho", "σ": "sigma", "τ": "tau", "υ": "upsilon",
133133
"φ": "phi", "χ": "chi", "ψ": "psi", "ω": "omega",
134-
# 大文字
134+
135135
"Α": "Alpha", "Β": "Beta", "Γ": "Gamma", "Δ": "Delta",
136136
"Ε": "Epsilon", "Ζ": "Zeta", "Η": "Eta", "Θ": "Theta",
137137
"Ι": "Iota", "Κ": "Kappa", "Λ": "Lambda", "Μ": "Mu",
@@ -149,8 +149,13 @@ def extract_symbols(expr: str) -> dict:
149149

150150
# high_order_pattern = r"\b(?:d|del)\*\*\d+[a-zA-Z_]\w*/(?:d|del)[a-zA-Z_]\w*\*\*\d+\b"
151151
# first_order_pattern = r"\b(?:d|del)[a-zA-Z_]\w*/(?:d|del)[a-zA-Z_]\w*\b"
152-
material_pattern = r"\bD_[a-zA-Z_]\w*_[a-zA-Z_]\w*\b"
153-
152+
material_pattern = (
153+
r"\bD_[A-Za-z_]\w*_[A-Za-z_]\w*\b" # D_var1_var2 形式
154+
r"|" # または
155+
r"\bD[A-Za-z_]\w*/D[A-Za-z_]\w*\b" # Dy/Dx 形式
156+
r"|" # または
157+
r"\bD/D[A-Za-z_]\w*(?:\(\s*[A-Za-z_]\w*\s*\)|\s+[A-Za-z_]\w*)" # D/Dx(y) 形式
158+
)
154159
# intg_pattern = r"(?:o?intg)\((?:[^()]+|\((?:[^()]+|\([^()]*\))*\))*\)"
155160

156161
# nabla_pattern = r"(?:\bgrad\b|\bdivg\b|\brot\b|\bdot\b|\bcross\b|\bvec\b|\bhat\b)"
@@ -194,16 +199,17 @@ def is_equivalent_sympy(expr1, expr2, params) -> Union[bool, None]:
194199
fd_dict = vars(fd)
195200
valid_types = (Basic, MatrixBase)
196201
fd_func_dict = {
197-
k.replace('_func', ''): v
202+
k[:-5]: v
198203
for k, v in fd_dict.items()
199204
if k.endswith('_func') and isinstance(v, FunctionClass)
200205
}
201-
206+
func_names = set(fd_func_dict.keys())
202207
fd_filtered = {
203208
k: v for k, v in fd_dict.items()
204209
if isinstance(v, valid_types)
205210
and not k.endswith('_func')
206211
and not isinstance(v, AppliedUndef)
212+
and k not in func_names
207213
}
208214

209215
# Build local_dict for parser
@@ -212,18 +218,28 @@ def is_equivalent_sympy(expr1, expr2, params) -> Union[bool, None]:
212218
"Divergence": fd.Divergence,
213219
"Curl": fd.Curl,
214220
"smart_derivative": fd.smart_derivative,
221+
"smart_dot": fd.smart_dot,
215222
**fd_func_dict,
216223
**fd_filtered,
217224
# **symbols1,
218225
# **symbols2,
219226
}
220227

221228
for name, sym in raw_dict.items():
222-
if not isinstance(sym, dict):
223-
if params:
224-
local_dict[name] = sym
225-
elif name not in local_dict:
226-
local_dict[name] = sym
229+
if isinstance(sym, dict):
230+
local_dict[name] = Symbol(name, **sym)
231+
else:
232+
if params:
233+
if isinstance(params, dict) and params:
234+
local_dict[name] = sym
235+
elif name not in local_dict:
236+
local_dict[name] = sym
237+
238+
if isinstance(params, dict) and params.get("function"):
239+
if isinstance(expr1, str):
240+
expr1 = apply_declared_functions(expr1, params["function"])
241+
if isinstance(expr2, str):
242+
expr2 = apply_declared_functions(expr2, params["function"])
227243

228244
def ensure_expr(expr):
229245
if isinstance(expr, str):
@@ -245,7 +261,6 @@ def ensure_expr(expr):
245261
args = func_args_map[name]
246262
applied_funcs[name] = func(*args)
247263

248-
# 置換:既に呼び出されていない関数だけ変換
249264
for name, applied in applied_funcs.items():
250265
pattern = rf'(?<!\w){name}(?!\w|\s*\()'
251266
expr = re.sub(pattern, f'({str(applied)})', expr)
@@ -256,6 +271,23 @@ def ensure_expr(expr):
256271

257272
# Handle equations
258273
if "=" in str(expr1) and "=" in str(expr2):
274+
def unwrap_parens_if_equal(s: str) -> str:
275+
s = s.strip()
276+
if "=" not in s:
277+
if s.startswith("(") and s.endswith(")") and is_matching_parens(s):
278+
inner = s[1:-1].strip()
279+
return inner
280+
else:
281+
left, right = s.split("=")
282+
if left.startswith("(") and left.endswith(")") and is_matching_parens(left):
283+
left = left[1:-1].strip()
284+
if right.startswith("(") and right.endswith(")") and is_matching_parens(right):
285+
right = right[1:-1].strip()
286+
return left+"="+right
287+
return s
288+
if isinstance(expr1, str) and isinstance(expr2, str):
289+
expr1 = unwrap_parens_if_equal(expr1)
290+
expr2 = unwrap_parens_if_equal(expr2)
259291
lhs1, rhs1 = str(expr1).split("=")
260292
lhs2, rhs2 = str(expr2).split("=")
261293

@@ -276,6 +308,9 @@ def ensure_expr(expr):
276308
# Handle expression comparison
277309
expr1_parsed = ensure_expr(expr1)
278310
expr2_parsed = ensure_expr(expr2)
311+
expr1 = apply_declared_functions(expr1, params.get("function", []))
312+
expr2 = apply_declared_functions(expr2, params.get("function", []))
313+
279314

280315
if isinstance(expr1_parsed, MatrixBase) and isinstance(expr2_parsed, MatrixBase):
281316
return simplify(expr1_parsed - expr2_parsed) == Matrix.zeros(*expr1_parsed.shape)
@@ -284,8 +319,30 @@ def ensure_expr(expr):
284319

285320

286321
except Exception as e:
287-
print(f"SymPy error: {e}")
288-
return False
322+
if "could not solve" in str(e).lower():
323+
try:
324+
expr1_symbolised, expr2_symbolised, symbol_map = symbolise_by_operators(expr1, expr2)
325+
if "=" in str(expr1) and "=" in str(expr2):
326+
lhs1, rhs1 = str(expr1_symbolised).split("=")
327+
lhs2, rhs2 = str(expr2_symbolised).split("=")
328+
lhs1_expr = parse_expr(lhs1.strip())
329+
rhs1_expr = parse_expr(rhs1.strip())
330+
lhs2_expr = parse_expr(lhs2.strip())
331+
rhs2_expr = parse_expr(rhs2.strip())
332+
eq1 = Eq(lhs1_expr - rhs1_expr, 0)
333+
eq2 = Eq(lhs2_expr - rhs2_expr, 0)
334+
all_symbols = eq1.free_symbols.union(eq2.free_symbols)
335+
sol1 = solve(eq1, list(all_symbols))
336+
sol2 = solve(eq2, list(all_symbols))
337+
return set(sol1) == set(sol2)
338+
else:
339+
return simplify(expr1_symbolised - expr2_symbolised) == 0
340+
except Exception as fallback_error:
341+
print(f"Fallback error: {fallback_error}")
342+
return None
343+
else:
344+
print(f"SymPy error: {e}")
345+
return False
289346

290347
def evaluation_function(response, answer, params):
291348
if has_unbalanced_parentheses(response) or has_unbalanced_parentheses(answer):
@@ -298,6 +355,8 @@ def evaluation_function(response, answer, params):
298355
}
299356
response = response.replace("^", "**")
300357
answer = answer.replace("^", "**")
358+
response = response.replace(" ", "")
359+
answer = answer.replace(" ", "")
301360
response = replace_greek_symbols(response)
302361
answer = replace_greek_symbols(answer)
303362
print(response, answer)
@@ -321,22 +380,44 @@ def evaluation_function(response, answer, params):
321380

322381
if needs_conversion:
323382
if response_has_other:
324-
response = convert_other(response, params).content.strip()
383+
# response = convert_other(response, params).content.strip()
384+
response = convert_other_re(response, params)
325385
if response_has_diff:
326-
response = convert_diff(response, params).content.strip()
386+
# response = convert_diff(response, params).content.strip()
387+
response = convert_diff_re(response, params)
327388
if response_has_integral:
328-
response = convert_integral(response, params).content.strip()
389+
# response = convert_integral(response, params).content.strip()
390+
response = convert_integral_re(response, params)
329391

330392
if answer_has_other:
331-
answer = convert_other(answer, params).content.strip()
393+
# answer = convert_other(answer, params).content.strip()
394+
answer = convert_other_re(answer, params)
332395
if answer_has_diff:
333-
answer = convert_diff(answer, params).content.strip()
396+
# answer = convert_diff(answer, params).content.strip()
397+
answer = convert_diff_re(answer, params)
334398
if answer_has_integral:
335-
answer = convert_integral(answer, params).content.strip()
399+
# answer = convert_integral(answer, params).content.strip()
400+
answer = convert_integral_re(answer, params)
336401

337-
print(response, answer) #parentheses not removed but will be removed later
402+
print("response converted: " + response)
403+
print("answer converted: " + answer)
338404
response = strip_outer_parens(response)
339405
answer = strip_outer_parens(answer)
340406
result = is_equivalent_sympy(response, answer, params)
341407

408+
if "domain" in params and params["domain"]:
409+
try:
410+
domain = params["domain"]
411+
if isinstance(domain, str):
412+
domain = parse_domain(domain)
413+
414+
val = sympify(response)
415+
416+
if val.is_number:
417+
if not check_in_domain(val, domain):
418+
return {"is_correct": False}
419+
420+
except Exception as e:
421+
print("error in domain check:", e)
422+
342423
return {"is_correct": result}

0 commit comments

Comments
 (0)