11from typing import TypedDict
22from typing_extensions import NotRequired
33
4- import sympy
4+ from sympy import Symbol
55from latex2sympy2 import latex2sympy
66
77from .expression_utilities import (
88 extract_latex ,
99 SymbolDict ,
1010 find_matching_parenthesis ,
11+ create_expression_set ,
1112)
1213
1314
@@ -27,38 +28,50 @@ class Result(TypedDict):
2728 preview : Preview
2829
2930
30- def parse_latex (response : str , symbols : SymbolDict , simplify : bool ) -> str :
31+ def parse_latex (response : str , symbols : SymbolDict , simplify : bool , parameters = None ) -> str :
3132 """Parse a LaTeX string to a sympy string while preserving custom symbols.
3233
3334 Args:
3435 response (str): The LaTeX expression to parse.
3536 symbols (SymbolDict): A mapping of sympy symbol strings and LaTeX
36- symbol strings.
37+ symbol strings.
38+ simplify (bool): If set to false the preview will attempt to preserve
39+ the way that the response was written as much as possible. If set
40+ to True the response will be simplified before the preview string
41+ is generated.
42+ parameters (dict): parameters used when generating sympy output when
43+ the response is written in LaTeX
3744
3845 Raises:
3946 ValueError: If the LaTeX string or symbol couldn't be parsed.
4047
4148 Returns:
4249 str: The expression in sympy syntax.
4350 """
51+ if parameters is None :
52+ parameters = dict ()
53+
4454 substitutions = {}
4555
4656 pm_placeholder = None
4757 mp_placeholder = None
4858
4959 if r"\pm " in response or r"\mp " in response :
60+ response_set = set ()
5061 for char in 'abcdefghjkoqrtvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' :
5162 if char not in response and pm_placeholder is None :
5263 pm_placeholder = char
64+ substitutions [pm_placeholder ] = Symbol (pm_placeholder , commutative = False )
5365 elif char not in response and mp_placeholder is None :
5466 mp_placeholder = char
67+ substitutions [mp_placeholder ] = Symbol (mp_placeholder , commutative = False )
5568 if pm_placeholder is not None and mp_placeholder is not None :
5669 break
57-
58- if pm_placeholder is not None :
59- response = response . replace ( r"\pm " , pm_placeholder )
60- if mp_placeholder is not None :
61- response = response . replace ( r"\mp " , mp_placeholder )
70+ for expr in create_expression_set ( response . replace ( r"\pm " , 'plus_minus' ). replace ( r"\mp " , 'minus_plus' ), parameters ):
71+ response_set . add ( expr )
72+ response = response_set
73+ else :
74+ response_set = { response }
6275
6376 for sympy_symbol_str in symbols :
6477 symbol_str = symbols [sympy_symbol_str ]["latex" ]
@@ -72,28 +85,25 @@ def parse_latex(response: str, symbols: SymbolDict, simplify: bool) -> str:
7285 f"Couldn't parse latex symbol { latex_symbol_str } "
7386 f"to sympy symbol."
7487 )
75- substitutions [latex_symbol ] = sympy .Symbol (sympy_symbol_str )
76-
77- substitutions .update ({r"\pm " : pm_placeholder , r"\mp " : mp_placeholder })
78-
79- try :
80- expression = latex2sympy (response , substitutions )
81- if isinstance (expression , list ):
82- expression = expression .pop ()
83- if simplify is True :
84- expression = expression .simplify ()
85- except Exception as e :
86- raise ValueError (str (e ))
87-
88- result_str = str (expression .xreplace (substitutions ))
89- for ph in [(pm_placeholder , "plus_minus" ), (mp_placeholder , "minus_plus" )]:
90- if ph [0 ] is not None :
91- result_str = result_str .replace ("*" + ph [0 ]+ "*" , " " + ph [1 ]+ " " )
92- result_str = result_str .replace (ph [0 ]+ "*" , " " + ph [1 ]+ " " )
93- result_str = result_str .replace ("*" + ph [0 ], " " + ph [1 ]+ " " )
94- result_str = result_str .replace (ph [0 ], " " + ph [1 ]+ " " )
95-
96- return result_str
88+ substitutions [latex_symbol ] = Symbol (sympy_symbol_str )
89+
90+ parsed_responses = set ()
91+ for expression in response_set :
92+ try :
93+ expression = latex2sympy (expression , substitutions )
94+ if isinstance (expression , list ):
95+ expression = expression .pop ()
96+ if simplify is True :
97+ expression = expression .simplify ()
98+ except Exception as e :
99+ raise ValueError (str (e ))
100+
101+ parsed_responses .add (str (expression .xreplace (substitutions )))
102+
103+ if len (parsed_responses ) < 2 :
104+ return parsed_responses .pop ()
105+ else :
106+ return '{' + ', ' .join (parsed_responses )+ '}'
97107
98108
99109def sanitise_latex (response ):
0 commit comments