@@ -49,21 +49,24 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool, parameters=
4949 pm_placeholder = None
5050 mp_placeholder = None
5151
52+ results = set ()
53+
5254 if r"\pm " in response or r"\mp " in response :
55+ response_set = set ()
5356 for char in 'abcdefghjkoqrtvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' :
5457 if char not in response and pm_placeholder is None :
5558 pm_placeholder = char
59+ substitutions [pm_placeholder ] = sympy .Symbol (pm_placeholder , commutative = False )
5660 elif char not in response and mp_placeholder is None :
5761 mp_placeholder = char
62+ substitutions [mp_placeholder ] = sympy .Symbol (mp_placeholder , commutative = False )
5863 if pm_placeholder is not None and mp_placeholder is not None :
5964 break
60-
61- if pm_placeholder is not None :
62- response = response .replace (r"\pm " , pm_placeholder )
63- substitutions [pm_placeholder ] = sympy .Symbol (pm_placeholder , commutative = False )
64- if mp_placeholder is not None :
65- response = response .replace (r"\mp " , mp_placeholder )
66- substitutions [mp_placeholder ] = sympy .Symbol (mp_placeholder , commutative = False )
65+ for expr in create_expression_set (response .replace (r"\pm " ,'plus_minus' ).replace (r"\mp " ,'minus_plus' ), parameters ):
66+ response_set .add (expr )
67+ response = response_set
68+ else :
69+ response_set = {response }
6770
6871 for sympy_symbol_str in symbols :
6972 symbol_str = symbols [sympy_symbol_str ]["latex" ]
@@ -79,28 +82,23 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool, parameters=
7982 )
8083 substitutions [latex_symbol ] = sympy .Symbol (sympy_symbol_str )
8184
82- try :
83- expression = latex2sympy (response , substitutions )
84- if isinstance (expression , list ):
85- expression = expression .pop ()
86- if simplify is True :
87- expression = expression .simplify ()
88- except Exception as e :
89- raise ValueError (str (e ))
90-
91- if (pm_placeholder is not None ) or (mp_placeholder is not None ):
92- result_str = str (expression )
93- for ph in [(pm_placeholder , "plus_minus" ), (mp_placeholder , "minus_plus" )]:
94- if ph [0 ] is not None :
95- result_str = result_str .replace ("*" + ph [0 ]+ "*" , " " + ph [1 ]+ " " )
96- result_str = result_str .replace (ph [0 ]+ "*" , " " + ph [1 ]+ " " )
97- result_str = result_str .replace ("*" + ph [0 ], " " + ph [1 ]+ " " )
98- result_str = result_str .replace (ph [0 ], " " + ph [1 ]+ " " )
85+ parsed_responses = set ()
86+ for expression in response_set :
87+ try :
88+ expression = latex2sympy (expression , substitutions )
89+ if isinstance (expression , list ):
90+ expression = expression .pop ()
91+ if simplify is True :
92+ expression = expression .simplify ()
93+ except Exception as e :
94+ raise ValueError (str (e ))
95+
96+ parsed_responses .add (str (expression .xreplace (substitutions )))
97+
98+ if len (parsed_responses ) < 2 :
99+ return parsed_responses .pop ()
99100 else :
100- result_str = str (expression .xreplace (substitutions ))
101-
102- return result_str
103-
101+ return '{' + ', ' .join (parsed_responses )+ '}'
104102
105103def sanitise_latex (response ):
106104 response = "" .join (response .split ())
0 commit comments