88 extract_latex ,
99 SymbolDict ,
1010 find_matching_parenthesis ,
11+ create_expression_set ,
1112)
1213
1314class Params (TypedDict ):
@@ -26,7 +27,7 @@ class Result(TypedDict):
2627 preview : Preview
2728
2829
29- def parse_latex (response : str , symbols : SymbolDict , simplify : bool ) -> str :
30+ def parse_latex (response : str , symbols : SymbolDict , simplify : bool , parameters = None ) -> str :
3031 """Parse a LaTeX string to a sympy string while preserving custom symbols.
3132
3233 Args:
@@ -40,6 +41,9 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool) -> str:
4041 Returns:
4142 str: The expression in sympy syntax.
4243 """
44+ if parameters is None :
45+ parameters = dict ()
46+
4347 substitutions = {}
4448
4549 pm_placeholder = None
@@ -56,8 +60,10 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool) -> str:
5660
5761 if pm_placeholder is not None :
5862 response = response .replace (r"\pm " , pm_placeholder )
63+ substitutions [pm_placeholder ] = sympy .Symbol (pm_placeholder , commutative = False )
5964 if mp_placeholder is not None :
6065 response = response .replace (r"\mp " , mp_placeholder )
66+ substitutions [mp_placeholder ] = sympy .Symbol (mp_placeholder , commutative = False )
6167
6268 for sympy_symbol_str in symbols :
6369 symbol_str = symbols [sympy_symbol_str ]["latex" ]
@@ -73,8 +79,6 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool) -> str:
7379 )
7480 substitutions [latex_symbol ] = sympy .Symbol (sympy_symbol_str )
7581
76- substitutions .update ({r"\pm " : pm_placeholder , r"\mp " : mp_placeholder })
77-
7882 try :
7983 expression = latex2sympy (response , substitutions )
8084 if isinstance (expression , list ):
@@ -84,13 +88,20 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool) -> str:
8488 except Exception as e :
8589 raise ValueError (str (e ))
8690
87- result_str = str (expression .xreplace (substitutions ))
88- for ph in [(pm_placeholder , "plus_minus" ), (mp_placeholder , "minus_plus" )]:
89- if ph [0 ] is not None :
90- result_str = result_str .replace ("*" + ph [0 ]+ "*" , " " + ph [1 ]+ " " )
91- result_str = result_str .replace (ph [0 ]+ "*" , " " + ph [1 ]+ " " )
92- result_str = result_str .replace ("*" + ph [0 ], " " + ph [1 ]+ " " )
93- result_str = result_str .replace (ph [0 ], " " + ph [1 ]+ " " )
91+ if (pm_placeholder is not None ) or (mp_placeholder is not None ):
92+ result_str_set = set ()
93+ result_str = str (expression )
94+ for ph in [(pm_placeholder , "plus_minus" ), (mp_placeholder , "minus_plus" )]:
95+ if ph [0 ] is not None :
96+ result_str = result_str .replace ("*" + ph [0 ]+ "*" , " " + ph [1 ]+ " " )
97+ result_str = result_str .replace (ph [0 ]+ "*" , " " + ph [1 ]+ " " )
98+ result_str = result_str .replace ("*" + ph [0 ], " " + ph [1 ]+ " " )
99+ result_str = result_str .replace (ph [0 ], " " + ph [1 ]+ " " )
100+ for expr in create_expression_set (result_str , parameters ):
101+ result_str_set .add (expr )
102+ result_str = '{' + ', ' .join (result_str_set )+ '}'
103+ else :
104+ result_str = str (expression .xreplace (substitutions ))
94105
95106 return result_str
96107
0 commit comments