Skip to content

Commit 50290c0

Browse files
Fixed bug with latex input that contained \pm or \mp
1 parent fcee544 commit 50290c0

File tree

4 files changed

+100
-34
lines changed

4 files changed

+100
-34
lines changed

app/evaluation_tests.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,17 @@ def test_eval_function_can_handle_latex_input(self):
4242
result = evaluation_function(response, answer, params)
4343
assert result["is_correct"] is True
4444

45+
def test_eval_function_preserves_order_in_latex_input(self):
46+
response = r"c + a + b"
47+
answer = "c + a + b"
48+
params = {
49+
"strict_syntax": False,
50+
"elementary_functions": True,
51+
"is_latex": True
52+
}
53+
result = evaluation_function(response, answer, params)
54+
assert result["is_correct"] is True
55+
4556
def test_AERO40007_1_6_instance_2024_25(self):
4657
params = {
4758
"strict_syntax": False,

app/tests/symbolic_preview_tests.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,44 @@ def test_sympy_with_equality_symbol(self):
7777
preview = result["preview"]
7878
assert preview.get("latex") == "\\frac{x^{2} + x + x}{x} = 1"
7979

80+
def test_latex_with_plus_minus(self):
81+
response = r"\pm \frac{3}{\sqrt{5}} i"
82+
params = Params(
83+
is_latex=True,
84+
simplify=False,
85+
complexNumbers=True,
86+
symbols={
87+
"I": {
88+
"latex": "$i$",
89+
"aliases": ["i"],
90+
},
91+
"plus_minus": {
92+
"latex": "$\\pm$",
93+
"aliases": ["pm", "+-"],
94+
},
95+
}
96+
)
97+
result = preview_function(response, params)
98+
preview = result["preview"]
99+
assert preview.get("sympy") in {'{3*(sqrt(5)/5)*I, -3*sqrt(5)/5*I}', '{-3*sqrt(5)/5*I, 3*(sqrt(5)/5)*I}'}
100+
assert preview.get("latex") == r'\pm \frac{3}{\sqrt{5}} i'
101+
response = r"4 \pm \sqrt{6}}"
102+
params = Params(
103+
is_latex=True,
104+
simplify=False,
105+
complexNumbers=True,
106+
symbols={
107+
"plus_minus": {
108+
"latex": "$\\pm$",
109+
"aliases": ["pm", "+-"],
110+
},
111+
}
112+
)
113+
result = preview_function(response, params)
114+
preview = result["preview"]
115+
assert preview.get("sympy") in {'{sqrt(6) + 4, 4 - sqrt(6)}', '{4 - sqrt(6), sqrt(6) + 4}'}
116+
assert preview.get("latex") == r'4 \pm \sqrt{6}}'
117+
80118
def test_latex_conversion_preserves_default_symbols(self):
81119
response = "\\mu + x + 1"
82120
params = Params(is_latex=True, simplify=False)

app/utility/expression_utilities.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,25 @@ def _print_log(self, expr, exp=None):
7272
# -------- String Manipulation Utilities
7373
def create_expression_set(exprs, params):
7474
if isinstance(exprs, str):
75-
exprs = [exprs]
75+
if exprs.startswith('{') and exprs.endswith('}'):
76+
exprs = [expr.strip() for expr in exprs[1:-1].split(',')]
77+
else:
78+
exprs = [exprs]
7679
expr_set = set()
77-
7880
for expr in exprs:
7981
expr = substitute_input_symbols(expr, params)[0]
8082
if "plus_minus" in params.keys():
8183
expr = expr.replace(params["plus_minus"], "plus_minus")
84+
8285
if "minus_plus" in params.keys():
8386
expr = expr.replace(params["minus_plus"], "minus_plus")
87+
8488
if ("plus_minus" in expr) or ("minus_plus" in expr):
85-
expr_set.add(expr.replace("plus_minus", "+").replace("minus_plus", "-"))
86-
expr_set.add(expr.replace("plus_minus", "-").replace("minus_plus", "+"))
89+
for pm_mp_ops in [("+","-"),("-","+")]:
90+
expr_string = expr.replace("plus_minus", pm_mp_ops[0]).replace("minus_plus", pm_mp_ops[1]).strip()
91+
while expr_string[0] == "+":
92+
expr_string = expr_string[1:]
93+
expr_set.add(expr_string.strip())
8794
else:
8895
expr_set.add(expr)
8996

app/utility/preview_utilities.py

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from typing import TypedDict
22
from typing_extensions import NotRequired
33

4-
import sympy
4+
from sympy import Symbol
55
from latex2sympy2 import latex2sympy
66

77
from .expression_utilities import (
88
extract_latex,
99
SymbolDict,
1010
find_matching_parenthesis,
11+
create_expression_set,
1112
)
1213

1314

@@ -27,38 +28,50 @@ class Result(TypedDict):
2728
preview: Preview
2829

2930

30-
def parse_latex(response: str, symbols: SymbolDict, simplify: bool) -> str:
31+
def parse_latex(response: str, symbols: SymbolDict, simplify: bool, parameters=None) -> str:
3132
"""Parse a LaTeX string to a sympy string while preserving custom symbols.
3233
3334
Args:
3435
response (str): The LaTeX expression to parse.
3536
symbols (SymbolDict): A mapping of sympy symbol strings and LaTeX
36-
symbol strings.
37+
symbol strings.
38+
simplify (bool): If set to false the preview will attempt to preserve
39+
the way that the response was written as much as possible. If set
40+
to True the response will be simplified before the preview string
41+
is generated.
42+
parameters (dict): parameters used when generating sympy output when
43+
the response is written in LaTeX
3744
3845
Raises:
3946
ValueError: If the LaTeX string or symbol couldn't be parsed.
4047
4148
Returns:
4249
str: The expression in sympy syntax.
4350
"""
51+
if parameters is None:
52+
parameters = dict()
53+
4454
substitutions = {}
4555

4656
pm_placeholder = None
4757
mp_placeholder = None
4858

4959
if r"\pm " in response or r"\mp " in response:
60+
response_set = set()
5061
for char in 'abcdefghjkoqrtvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ':
5162
if char not in response and pm_placeholder is None:
5263
pm_placeholder = char
64+
substitutions[pm_placeholder] = Symbol(pm_placeholder, commutative=False)
5365
elif char not in response and mp_placeholder is None:
5466
mp_placeholder = char
67+
substitutions[mp_placeholder] = Symbol(mp_placeholder, commutative=False)
5568
if pm_placeholder is not None and mp_placeholder is not None:
5669
break
57-
58-
if pm_placeholder is not None:
59-
response = response.replace(r"\pm ", pm_placeholder)
60-
if mp_placeholder is not None:
61-
response = response.replace(r"\mp ", mp_placeholder)
70+
for expr in create_expression_set(response.replace(r"\pm ",'plus_minus').replace(r"\mp ",'minus_plus'), parameters):
71+
response_set.add(expr)
72+
response = response_set
73+
else:
74+
response_set = {response}
6275

6376
for sympy_symbol_str in symbols:
6477
symbol_str = symbols[sympy_symbol_str]["latex"]
@@ -72,28 +85,25 @@ def parse_latex(response: str, symbols: SymbolDict, simplify: bool) -> str:
7285
f"Couldn't parse latex symbol {latex_symbol_str} "
7386
f"to sympy symbol."
7487
)
75-
substitutions[latex_symbol] = sympy.Symbol(sympy_symbol_str)
76-
77-
substitutions.update({r"\pm ": pm_placeholder, r"\mp ": mp_placeholder})
78-
79-
try:
80-
expression = latex2sympy(response, substitutions)
81-
if isinstance(expression, list):
82-
expression = expression.pop()
83-
if simplify is True:
84-
expression = expression.simplify()
85-
except Exception as e:
86-
raise ValueError(str(e))
87-
88-
result_str = str(expression.xreplace(substitutions))
89-
for ph in [(pm_placeholder, "plus_minus"), (mp_placeholder, "minus_plus")]:
90-
if ph[0] is not None:
91-
result_str = result_str.replace("*"+ph[0]+"*", " "+ph[1]+" ")
92-
result_str = result_str.replace(ph[0]+"*", " "+ph[1]+" ")
93-
result_str = result_str.replace("*"+ph[0], " "+ph[1]+" ")
94-
result_str = result_str.replace(ph[0], " "+ph[1]+" ")
95-
96-
return result_str
88+
substitutions[latex_symbol] = Symbol(sympy_symbol_str)
89+
90+
parsed_responses = set()
91+
for expression in response_set:
92+
try:
93+
expression = latex2sympy(expression, substitutions)
94+
if isinstance(expression, list):
95+
expression = expression.pop()
96+
if simplify is True:
97+
expression = expression.simplify()
98+
except Exception as e:
99+
raise ValueError(str(e))
100+
101+
parsed_responses.add(str(expression.xreplace(substitutions)))
102+
103+
if len(parsed_responses) < 2:
104+
return parsed_responses.pop()
105+
else:
106+
return '{'+', '.join(parsed_responses)+'}'
97107

98108

99109
def sanitise_latex(response):

0 commit comments

Comments
 (0)