Skip to content

Commit fa588fc

Browse files
FMX2 symbols added
1 parent 74445eb commit fa588fc

File tree

6 files changed

+292
-195
lines changed

6 files changed

+292
-195
lines changed

app/FMX2_symbols.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from sympy import symbols, Function, Matrix, Expr, Derivative
2+
from sympy.matrices import MatrixBase
3+
4+
class Fluids:
5+
def __init__(self):
6+
# 座標
7+
self.x, self.y, self.z, self.t = symbols('x y z t')
8+
self.r, self.theta = symbols('r theta')
9+
10+
# 物理量(スカラー場)
11+
self.u_func = Function('u')
12+
self.v_func = Function('v')
13+
self.w_func = Function('w')
14+
self.T_func = Function('T')
15+
self.rho_func = Function('rho')
16+
self.p_func = Function('p')
17+
18+
self.u = self.u_func(self.x, self.y, self.z, self.t)
19+
self.v = self.v_func(self.x, self.y, self.z, self.t)
20+
self.w = self.w_func(self.x, self.y, self.z, self.t)
21+
self.T = self.T_func(self.x, self.y, self.z, self.t)
22+
self.rho = self.rho_func(self.x, self.y, self.z, self.t)
23+
self.p = self.p_func(self.x, self.y, self.z, self.t)
24+
25+
# 定数
26+
self.a, self.b, self.c, self.h = symbols('a b c h')
27+
self.m, self.g, self.mu, self.nu, self.R, self.c_p, self.c_v, self.kappa = symbols(
28+
'm g mu nu R c_p c_v kappa', positive=True
29+
)
30+
31+
# 速度ベクトル(Matrixベース)
32+
self.u_vec = Matrix([[self.u],[self.v],[self.w]])
33+
self.x_vec = Matrix([[self.x],[self.y],[self.z]])
34+
35+
# 勾配テンソル(成分ごとの微分)
36+
self.grad_u = self.Gradient(self.u_vec)
37+
38+
# 円筒座標系の速度(Matrixベース)
39+
self.u_r_func = Function('u_r')
40+
self.u_theta_func = Function('u_theta')
41+
self.u_z_func = Function('u_z')
42+
43+
self.u_r = self.u_r_func(self.r, self.theta, self.z, self.t)
44+
self.u_theta = self.u_theta_func(self.r, self.theta, self.z, self.t)
45+
self.u_z = self.u_z_func(self.r, self.theta, self.z, self.t)
46+
47+
self.u_vec_cyl = Matrix([[self.u_r],[self.u_theta],[self.u_z]])
48+
49+
self.x_hat = Matrix([[1],[0],[0]])
50+
self.y_hat = Matrix([[0],[1],[0]])
51+
self.z_hat = Matrix([[0],[0],[1]])
52+
self.r_hat = Matrix([[1],[0],[0]])
53+
self.theta_hat = Matrix([[0],[1],[0]])
54+
55+
def Gradient(self, f):
56+
"""スカラーかベクトルに応じて勾配 or 勾配テンソルを返す"""
57+
if isinstance(f, Expr):
58+
return Matrix([
59+
[f.diff(self.x)],
60+
[f.diff(self.y)],
61+
[f.diff(self.z)]
62+
])
63+
64+
elif isinstance(f, MatrixBase) and f.shape == (3, 1):
65+
return Matrix([
66+
[f[0].diff(self.x), f[0].diff(self.y), f[0].diff(self.z)],
67+
[f[1].diff(self.x), f[1].diff(self.y), f[1].diff(self.z)],
68+
[f[2].diff(self.x), f[2].diff(self.y), f[2].diff(self.z)]
69+
])
70+
71+
else:
72+
raise TypeError("Gradient() expects a scalar Expr or a 3x1 Matrix.")
73+
74+
def Divergence(self, vec):
75+
"""ベクトルの発散(スカラー)"""
76+
if not (isinstance(vec, MatrixBase) and vec.shape == (3, 1)):
77+
raise TypeError("Divergence expects a 3x1 vector Matrix.")
78+
return vec[0].diff(self.x) + vec[1].diff(self.y) + vec[2].diff(self.z)
79+
80+
def Curl(self, vec):
81+
"""ベクトルの回転(3x1ベクトル)"""
82+
if not (isinstance(vec, MatrixBase) and vec.shape == (3, 1)):
83+
raise TypeError("Curl expects a 3x1 vector Matrix.")
84+
return Matrix([
85+
[vec[2].diff(self.y) - vec[1].diff(self.z)],
86+
[vec[0].diff(self.z) - vec[2].diff(self.x)],
87+
[vec[1].diff(self.x) - vec[0].diff(self.y)]
88+
])
89+
def smart_derivative(self, expr, var):
90+
# スカラーならそのまま
91+
if not isinstance(expr, MatrixBase):
92+
return Derivative(expr, var)
93+
# ベクトルなら成分ごとに
94+
return Matrix([[Derivative(expr[i], var)] for i in range(expr.rows)])

app/evaluation.py

Lines changed: 131 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,16 @@
22
from dotenv import load_dotenv
33
from langchain_openai import ChatOpenAI
44
from typing import Any, TypedDict, Union
5-
from sympy import solve, Eq, simplify, Symbol, Function, integrate, diff
5+
from sympy import solve, Eq, simplify, Expr, symbols, Symbol, Function, FunctionClass, Integral, Derivative, Matrix, Abs, sin, cos, tan, sqrt, log, exp
6+
from sympy.core.function import AppliedUndef
7+
from sympy.matrices import MatrixBase
8+
from sympy import Basic
69
from sympy.parsing.sympy_parser import parse_expr, standard_transformations, implicit_multiplication_application
710
import re
811
from parameter import create_sympy_parsing_params, Params
912
from re_conversion import convert_diff_re, convert_integral_re, convert_other_re
13+
from llm_conversion import convert_diff, convert_integral, convert_other
14+
from FMX2_symbols import Fluids
1015

1116

1217
class Result(TypedDict):
@@ -18,6 +23,21 @@ class Result(TypedDict):
1823

1924
transformations = standard_transformations + (implicit_multiplication_application,)
2025

26+
def strip_outer_parens(expr: str) -> str:
27+
expr = expr.strip()
28+
if not expr.startswith("(") or not expr.endswith(")"):
29+
return expr
30+
31+
depth = 0
32+
for i, char in enumerate(expr):
33+
if char == "(":
34+
depth += 1
35+
elif char == ")":
36+
depth -= 1
37+
if depth == 0 and i != len(expr) - 1:
38+
return expr
39+
return expr[1:-1].strip()
40+
2141

2242
def has_unbalanced_parentheses(expr: str) -> bool:
2343
"""
@@ -127,15 +147,15 @@ def replace_greek_symbols(expr: str) -> str:
127147

128148
def extract_symbols(expr: str) -> dict:
129149

130-
high_order_pattern = r"\b(?:d|del)\*\*\d+[a-zA-Z_]\w*/(?:d|del)[a-zA-Z_]\w*\*\*\d+\b"
131-
first_order_pattern = r"\b(?:d|del)[a-zA-Z_]\w*/(?:d|del)[a-zA-Z_]\w*\b"
150+
# high_order_pattern = r"\b(?:d|del)\*\*\d+[a-zA-Z_]\w*/(?:d|del)[a-zA-Z_]\w*\*\*\d+\b"
151+
# first_order_pattern = r"\b(?:d|del)[a-zA-Z_]\w*/(?:d|del)[a-zA-Z_]\w*\b"
132152
material_pattern = r"\bD_[a-zA-Z_]\w*_[a-zA-Z_]\w*\b"
133153

134-
intg_pattern = r"(?:o?intg)\((?:[^()]+|\((?:[^()]+|\([^()]*\))*\))*\)"
154+
# intg_pattern = r"(?:o?intg)\((?:[^()]+|\((?:[^()]+|\([^()]*\))*\))*\)"
135155

136-
nabla_pattern = r"(?:\bgrad\b|\bdivg\b|\brot\b|\bdot\b|\bcross\b|\bvec\b|\bhat\b)"
156+
# nabla_pattern = r"(?:\bgrad\b|\bdivg\b|\brot\b|\bdot\b|\bcross\b|\bvec\b|\bhat\b)"
137157

138-
combined_pattern = f"{high_order_pattern}|{first_order_pattern}|{material_pattern}|{intg_pattern}|{nabla_pattern}"
158+
combined_pattern = f"{material_pattern}"
139159

140160
matches = re.findall(combined_pattern, expr)
141161

@@ -147,73 +167,127 @@ def extract_symbols(expr: str) -> dict:
147167
def is_equivalent_sympy(expr1, expr2, params) -> Union[bool, None]:
148168
"""
149169
Return True/False if comparable with SymPy,
150-
or None if an error occurs.
170+
or False if an error occurs.
151171
"""
152-
if not expr1.strip() and not expr2.strip():
153-
return True
154-
if not expr1.strip() or not expr2.strip():
155-
return False
172+
173+
if isinstance(expr1, str) and isinstance(expr2, str):
174+
if not expr1.strip() and not expr2.strip():
175+
return True
176+
if not expr1.strip() or not expr2.strip():
177+
return False
156178

157179
try:
158-
parsing_params = create_sympy_parsing_params(params, expr1, expr2)
180+
# Always convert to string before parsing assumptions
181+
parsing_params = create_sympy_parsing_params(params, str(expr1), str(expr2))
159182
raw_dict = parsing_params["symbol_dict"]
160-
transformations = parsing_params.get("extra_transformations", standard_transformations + (implicit_multiplication_application,))
183+
transformations = parsing_params.get(
184+
"extra_transformations",
185+
standard_transformations + (implicit_multiplication_application,)
186+
)
161187

162-
symbols1 = extract_symbols(expr1)
163-
symbols2 = extract_symbols(expr2)
188+
# Optional: Extract symbols from string inputs
189+
# symbols1 = extract_symbols(expr1) if isinstance(expr1, str) else expr1
190+
# symbols2 = extract_symbols(expr2) if isinstance(expr2, str) else expr2
191+
192+
# Optional fluid context
193+
fd = Fluids()
194+
fd_dict = vars(fd)
195+
valid_types = (Basic, MatrixBase)
196+
fd_func_dict = {
197+
k.replace('_func', ''): v
198+
for k, v in fd_dict.items()
199+
if k.endswith('_func') and isinstance(v, FunctionClass)
200+
}
164201

202+
fd_filtered = {
203+
k: v for k, v in fd_dict.items()
204+
if isinstance(v, valid_types)
205+
and not k.endswith('_func')
206+
and not isinstance(v, AppliedUndef)
207+
}
208+
209+
# Build local_dict for parser
165210
local_dict = {
166-
**symbols1,
167-
**symbols2,
211+
"Gradient": fd.Gradient,
212+
"Divergence": fd.Divergence,
213+
"Curl": fd.Curl,
214+
"smart_derivative": fd.smart_derivative,
215+
**fd_func_dict,
216+
**fd_filtered,
217+
# **symbols1,
218+
# **symbols2,
168219
}
169-
local_dict["intg"] = Function("intg")
170-
local_dict["ointg"] = Function("ointg")
171-
local_dict["grad"] = Function("grad")
172-
local_dict["divg"] = Function("divg")
173-
local_dict["rot"] = Function("rot")
174-
local_dict["dot"] = Function("dot")
175-
local_dict["cross"] = Function("cross")
176-
local_dict["vec"] = Function("vec")
177-
local_dict["hat"] = Function("hat")
178-
179-
for name, attrs in raw_dict.items():
180-
if name == "integrate":
181-
local_dict[name] = Function(name)
182-
elif name not in local_dict:
183-
local_dict[name] = Symbol(name, **attrs) if isinstance(attrs, dict) else attrs
184-
185-
if "=" in expr1 and "=" in expr2:
186-
lhs1, rhs1 = expr1.split("=")
187-
lhs2, rhs2 = expr2.split("=")
188-
lhs1_parsed = parse_expr(lhs1, transformations=transformations, local_dict=local_dict)
189-
rhs1_parsed = parse_expr(rhs1, transformations=transformations, local_dict=local_dict)
190-
lhs2_parsed = parse_expr(lhs2, transformations=transformations, local_dict=local_dict)
191-
rhs2_parsed = parse_expr(rhs2, transformations=transformations, local_dict=local_dict)
220+
221+
for name, sym in raw_dict.items():
222+
if not isinstance(sym, dict):
223+
if params:
224+
local_dict[name] = sym
225+
elif name not in local_dict:
226+
local_dict[name] = sym
227+
228+
def ensure_expr(expr):
229+
if isinstance(expr, str):
230+
func_args_map = {
231+
"u": (fd.x, fd.y, fd.z, fd.t),
232+
"v": (fd.x, fd.y, fd.z, fd.t),
233+
"w": (fd.x, fd.y, fd.z, fd.t),
234+
"T": (fd.x, fd.y, fd.z, fd.t),
235+
"rho": (fd.x, fd.y, fd.z, fd.t),
236+
"p": (fd.x, fd.y, fd.z, fd.t),
237+
"u_r": (fd.r, fd.theta, fd.z, fd.t),
238+
"u_theta": (fd.r, fd.theta, fd.z, fd.t),
239+
"u_z": (fd.r, fd.theta, fd.z, fd.t),
240+
}
241+
242+
applied_funcs = {}
243+
for name, func in local_dict.items():
244+
if isinstance(func, FunctionClass) and name in func_args_map:
245+
args = func_args_map[name]
246+
applied_funcs[name] = func(*args)
247+
248+
# 置換:既に呼び出されていない関数だけ変換
249+
for name, applied in applied_funcs.items():
250+
pattern = rf'(?<!\w){name}(?!\w|\s*\()'
251+
expr = re.sub(pattern, f'({str(applied)})', expr)
252+
253+
return parse_expr(expr, transformations=transformations, local_dict=local_dict)
254+
else:
255+
return expr
256+
257+
# Handle equations
258+
if "=" in str(expr1) and "=" in str(expr2):
259+
lhs1, rhs1 = str(expr1).split("=")
260+
lhs2, rhs2 = str(expr2).split("=")
261+
262+
lhs1_parsed = ensure_expr(lhs1)
263+
rhs1_parsed = ensure_expr(rhs1)
264+
lhs2_parsed = ensure_expr(lhs2)
265+
rhs2_parsed = ensure_expr(rhs2)
266+
192267
eq1 = Eq(lhs1_parsed - rhs1_parsed, 0)
193268
eq2 = Eq(lhs2_parsed - rhs2_parsed, 0)
194269

195270
all_symbols = eq1.free_symbols.union(eq2.free_symbols)
196-
197271
sol1 = solve(eq1, list(all_symbols))
198272
sol2 = solve(eq2, list(all_symbols))
199273

200274
return set(sol1) == set(sol2)
201275

202-
# Parse expressions
203-
expr1_parsed = parse_expr(expr1, transformations=transformations, local_dict=local_dict)
204-
expr2_parsed = parse_expr(expr2, transformations=transformations, local_dict=local_dict)
276+
# Handle expression comparison
277+
expr1_parsed = ensure_expr(expr1)
278+
expr2_parsed = ensure_expr(expr2)
205279

206-
print("expr1_parsed:", expr1_parsed)
207-
print("expr2_parsed:", expr2_parsed)
280+
if isinstance(expr1_parsed, MatrixBase) and isinstance(expr2_parsed, MatrixBase):
281+
return simplify(expr1_parsed - expr2_parsed) == Matrix.zeros(*expr1_parsed.shape)
282+
else:
283+
return simplify(expr1_parsed - expr2_parsed) == 0
208284

209-
return simplify(expr1_parsed - expr2_parsed) == 0
210285

211286
except Exception as e:
212287
print(f"SymPy error: {e}")
213288
return False
214289

215290
def evaluation_function(response, answer, params):
216-
217291
if has_unbalanced_parentheses(response) or has_unbalanced_parentheses(answer):
218292
return {
219293
"is_correct": False,
@@ -247,21 +321,22 @@ def evaluation_function(response, answer, params):
247321

248322
if needs_conversion:
249323
if response_has_other:
250-
response = convert_other_re(response, params)
324+
response = convert_other(response, params).content.strip()
251325
if response_has_diff:
252-
response = convert_diff_re(response, params)
326+
response = convert_diff(response, params).content.strip()
253327
if response_has_integral:
254-
response = convert_integral_re(response, params)
328+
response = convert_integral(response, params).content.strip()
255329

256330
if answer_has_other:
257-
answer = convert_other_re(answer, params)
331+
answer = convert_other(answer, params).content.strip()
258332
if answer_has_diff:
259-
answer = convert_diff_re(answer, params)
333+
answer = convert_diff(answer, params).content.strip()
260334
if answer_has_integral:
261-
answer = convert_integral_re(answer, params)
262-
263-
print(response, answer)
335+
answer = convert_integral(answer, params).content.strip()
264336

337+
print(response, answer) #parentheses not removed but will be removed later
338+
response = strip_outer_parens(response)
339+
answer = strip_outer_parens(answer)
265340
result = is_equivalent_sympy(response, answer, params)
266341

267342
return {"is_correct": result}

app/evaluation_test_cases.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
["∫fdx", "int(f, x)", Params(), True],
2424
["dy/dx + 1", "diff(y, x) + 1", Params(), True],
2525
["dp/dt", "diff(p, t)", Params(), True],
26-
["dg/dm", "diff(y,x)", Params(), False],
27-
["infty", "Infinity", Params(), True], #1
26+
["du/dx", "diff(y,x)", Params(), False],
27+
["infty", "Infinity", Params(), True],
2828
["sqrt(-1)", "I", Params(), True],
2929
["sqrt(x**2)", "x", Params(), False],
3030
["1/(x-1)", "1/(1-x)", Params(), False],
@@ -41,12 +41,11 @@
4141
["abs(x)", "sqrt(x**2)", Params(), False],
4242
["abs(x)", "sqrt(x**2)", Params(symbol_assumptions={"x": {"real": True},}), True],
4343
]
44+
4445
test_cases2 = [
45-
["nabla(f)=delf/delr*hat(r)+delf/deltheta*hat(theta)*delf/delz*hat(z)","grad(f)=delf/delr*hat(r)+delf/deltheta*hat(theta)*delf/delz*hat(z)",Params(),True],
46-
["∫_{V_sys} ∂ρ/∂t dV", "int_Vsys(partial(ρ)/partial(t), V)", Params(), True],
47-
["∫_{V_sys} ∂/∂t(ρ*u) dV", "int_Vsys(partial(ρ*u)/partial(t), V)", Params(), True],
48-
["∮_{A_sys} ρ*u·n dA", "Integral(rho * u.dot(n), (A, A_sys), circular=True)", Params(), True],
49-
["exp(a/b)","e^(a/b)",Params(),True],
50-
["exp(i*(omega*t-k*r))","cos(omega*t-k*r)+i*sin(omega*t-k*r)",Params(),True],
51-
["grad(p)=rho(g-a)","nabla(p)=ρ*(g-a)",Params(),True],
46+
["grad(f)=delf/delr*hat(r)+delf/deltheta*hat(theta)+delf/delz*hat(z)", "grad(f)=delf/delr*hat(r)+delf/deltheta*hat(theta)+delf/delz*hat(z)", Params(), True],
47+
["Du_vec/Dt", "smart_derivative(u_vec,t) + grad(u_vec)*u_vec", Params(), True],
48+
["u_vec", "u_vec", Params(), True],
49+
["Gradient(u_vec)", "Gradient(u_vec)", Params(), True],
50+
["u_vec.dot(x_vec)", "u_vec.dot(x_vec)", Params(), True]
5251
]

0 commit comments

Comments
 (0)