Skip to content

Commit a762afc

Browse files
Added capacity to have order comparisons (>, >=, <, <=) in criteria
1 parent b467a44 commit a762afc

File tree

3 files changed

+73
-14
lines changed

3 files changed

+73
-14
lines changed

app/context/symbolic.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def check_criterion(criterion, parameters_dict, generate_feedback=True):
3535
parsing_params.update({"simplify": False})
3636
if label in {"EQUALITY", "WRITTEN_AS"}:
3737
result = check_equality(criterion, parameters_dict)
38+
if label == "ORDER":
39+
result = check_order(criterion, parameters_dict)
3840
elif label == "WHERE":
3941
crit = criterion.children[0]
4042
subs = criterion.children[1]
@@ -53,7 +55,7 @@ def check_criterion(criterion, parameters_dict, generate_feedback=True):
5355
return result
5456

5557

56-
def check_equality(criterion, parameters_dict, local_substitutions=[]):
58+
def create_expressions_for_comparison(criterion, parameters_dict, local_substitutions=[]):
5759
parsing_params = deepcopy(parameters_dict["parsing_parameters"])
5860
reserved_expressions = list(parameters_dict["reserved_expressions"].items())
5961
parsing_params.update(
@@ -65,7 +67,6 @@ def check_equality(criterion, parameters_dict, local_substitutions=[]):
6567
)
6668
lhs = criterion.children[0].content_string()
6769
rhs = criterion.children[1].content_string()
68-
6970
lhs_expr = parse_expression(lhs, parsing_params).subs(local_substitutions).subs(reserved_expressions).subs(local_substitutions)
7071
rhs_expr = parse_expression(rhs, parsing_params).subs(local_substitutions).subs(reserved_expressions).subs(local_substitutions)
7172
if parsing_params.get("complexNumbers", False):
@@ -74,15 +75,31 @@ def check_equality(criterion, parameters_dict, local_substitutions=[]):
7475
if (im(lhs_expr) != 0) or (im(lhs_expr) != 0):
7576
lhs_expr = real_part(simplified_lhs_expr) + I*im(simplified_lhs_expr)
7677
rhs_expr = real_part(simplified_rhs_expr) + I*im(simplified_rhs_expr)
77-
expression = (lhs_expr - rhs_expr)
78-
result = bool(expression.cancel().simplify().simplify() == 0)
78+
return lhs_expr, rhs_expr
79+
80+
81+
def do_comparison(comparison_symbol, expression):
    """Compare the simplified form of ``expression`` against zero.

    Args:
        comparison_symbol: One of "=", ">", ">=", "<" or "<=". Leading and
            trailing whitespace is stripped before the lookup.
        expression: Object providing ``cancel()`` and ``simplify()``; the
            callers in this file pass the difference ``lhs_expr - rhs_expr``.

    Returns:
        bool: Truth value of ``simplified_expression <symbol> 0``.

    Raises:
        KeyError: If the stripped symbol is not one of the supported
            comparisons.
    """
    # Each predicate acts on the value it is given instead of silently
    # closing over the outer variable.
    comparisons = {
        "=": lambda value: value == 0,
        ">": lambda value: value > 0,
        ">=": lambda value: value >= 0,
        "<": lambda value: value < 0,
        "<=": lambda value: value <= 0,
    }
    # Resolve the symbol first so an unsupported symbol fails before any
    # simplification work is done.
    comparison = comparisons[comparison_symbol.strip()]
    # The simplification chain is the same for every comparison, so it is
    # written (and evaluated) exactly once.
    simplified = expression.cancel().simplify().simplify()
    # NOTE(review): presumably `expression` is a SymPy expression; if it still
    # contains free symbols an ordering comparison may be undecidable and
    # bool() may raise TypeError - confirm callers handle that case.
    return bool(comparison(simplified))
92+
93+
94+
def check_equality(criterion, parameters_dict, local_substitutions=[]):
95+
lhs_expr, rhs_expr = create_expressions_for_comparison(criterion, parameters_dict, local_substitutions)
96+
result = do_comparison(criterion.content, lhs_expr-rhs_expr)
7997

8098
# TODO: Make numerical comparison its own context
8199
if result is False:
82100
error_below_rtol = None
83101
error_below_atol = None
84102
if parameters_dict.get("numerical", False) or float(parameters_dict.get("rtol", 0)) > 0 or float(parameters_dict.get("atol", 0)) > 0:
85-
86103
# REMARK: 'pi' should be a reserved symbol but it is sometimes not treated as one, possibly because of input symbols.
87104
# The two lines below this comment fix the issue but a more robust solution should be found for cases where there
88105
# are other reserved symbols.
@@ -92,12 +109,10 @@ def replace_pi(expr):
92109
if str(s) == 'pi':
93110
pi_symbol = s
94111
return expr.subs(pi_symbol, float(pi))
95-
96112
# NOTE: This code assumes that the left hand side is the response and the right hand side is the answer
97113
# Separates LHS and RHS, parses and evaluates them
98114
res = N(replace_pi(lhs_expr))
99115
ans = N(replace_pi(rhs_expr))
100-
101116
if float(parameters_dict.get("atol", 0)) > 0:
102117
try:
103118
absolute_error = abs(float(ans-res))
@@ -122,6 +137,12 @@ def replace_pi(expr):
122137
return result
123138

124139

140+
def check_order(criterion, parameters_dict, local_substitutions=None):
    """Evaluate an order criterion (>, >=, <, <=) between two expressions.

    Args:
        criterion: Criterion node; its two children supply the left- and
            right-hand sides and its ``content`` is used as the comparison
            symbol.
        parameters_dict: Evaluation parameters forwarded to
            ``create_expressions_for_comparison``.
        local_substitutions: Optional substitutions forwarded to
            ``create_expressions_for_comparison``. Defaults to an empty list.

    Returns:
        The result of ``do_comparison`` applied to ``lhs - rhs``.
    """
    # A None sentinel avoids sharing one mutable default list between calls.
    if local_substitutions is None:
        local_substitutions = []
    lhs_expr, rhs_expr = create_expressions_for_comparison(criterion, parameters_dict, local_substitutions)
    return do_comparison(criterion.content, lhs_expr - rhs_expr)
144+
145+
125146
def find_coords_for_node_type(expression, node_type):
126147
stack = [(expression, tuple())]
127148
node_coords = []
@@ -683,16 +704,18 @@ def get_candidates(unused_input):
683704

684705

685706
def criterion_eval_node(criterion, parameters_dict, generate_feedback=True):
707+
feedback_string_generator_inputs = {'criterion': criterion}
708+
686709
def evaluation_node_internal(unused_input):
687710
result = check_criterion(criterion, parameters_dict, generate_feedback)
688711
label = criterion.content_string()
689712
if result:
690713
return {
691-
label+"_TRUE": None
714+
label+"_TRUE": feedback_string_generator_inputs
692715
}
693716
else:
694717
return {
695-
label+"_FALSE": None
718+
label+"_FALSE": feedback_string_generator_inputs
696719
}
697720
label = criterion.content_string()
698721
graph = CriteriaGraph(label)

app/tests/symbolic_evaluation_tests.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1780,6 +1780,32 @@ def test_criteria_where_numerical_comparison(self, response, answer, criteria, v
17801780
for feedback_tag in feedback_tags:
17811781
assert feedback_tag in result["tags"]
17821782

1783+
@pytest.mark.parametrize(
    "response, answer, criteria, value",
    [
        ("1", "2", "response > answer", False),
        ("2", "2", "response > answer", False),
        ("3", "2", "response > answer", True),
        ("1", "2", "response >= answer", False),
        ("2", "2", "response >= answer", True),
        ("3", "2", "response >= answer", True),
        ("1", "2", "response < answer", True),
        ("2", "2", "response < answer", False),
        ("3", "2", "response < answer", False),
        ("1", "2", "response <= answer", True),
        ("2", "2", "response <= answer", True),
        ("3", "2", "response <= answer", False),
    ]
)
def test_criteria_order_comparison(self, response, answer, criteria, value):
    """Check that each order criterion yields the expected ``is_correct``.

    Every comparison operator is exercised with a response below, equal to
    and above the answer.
    """
    outcome = evaluation_function(
        response,
        answer,
        dict(strict_syntax=False, elementary_functions=True, criteria=criteria),
    )
    assert outcome["is_correct"] is value
1808+
17831809
@pytest.mark.parametrize(
17841810
"response, answer, value",
17851811
[

app/utility/criteria_parsing.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
(" *EQUAL_LIST *", "EQUAL_LIST"),
1515
(" *RESERVED *", "RESERVED"),
1616
(" *= *", "EQUALITY"),
17+
(" *(>=?|<=?|ORDER) *", "ORDER"), # less than (or equal), < (<=), greater than (or equal), > (>=)
1718
(" *where *", "WHERE"),
1819
(" *written as *", "WRITTEN_AS"),
1920
(" *; *", "SEPARATOR"),
@@ -23,6 +24,7 @@
2324
base_productions = [
2425
("START", "BOOL", create_node),
2526
("BOOL", "EQUAL", proceed),
27+
("BOOL", "ORDER", proceed),
2628
("BOOL", "EQUAL where EQUAL", infix),
2729
("BOOL", "EQUAL where EQUAL_LIST", infix),
2830
("BOOL", "RESERVED written as OTHER", infix),
@@ -33,6 +35,10 @@
3335
("EQUAL", "RESERVED = OTHER", infix),
3436
("EQUAL", "OTHER = RESERVED", infix),
3537
("EQUAL", "RESERVED = RESERVED", infix),
38+
("EQUAL", "OTHER ORDER OTHER", infix),
39+
("EQUAL", "RESERVED ORDER OTHER", infix),
40+
("EQUAL", "OTHER ORDER RESERVED", infix),
41+
("EQUAL", "RESERVED ORDER RESERVED", infix),
3642
("OTHER", "RESERVED OTHER", join),
3743
("OTHER", "OTHER RESERVED", join),
3844
("OTHER", "OTHER OTHER", join),
@@ -48,11 +54,15 @@ def generate_criteria_parser(reserved_expressions, token_list=base_token_list, p
4854

4955

5056
if __name__ == "__main__":
51-
test_criteria = [
52-
"a = b",
53-
"response = b",
54-
"a = response",
55-
"response = answer",
57+
test_criteria = []
58+
for comparison in ["=", ">", "<", ">=", "<="]:
59+
test_criteria += [
60+
f"a {comparison} b",
61+
f"response {comparison} b",
62+
f"a {comparison} response",
63+
f"response {comparison} answer",
64+
]
65+
test_criteria += [
5666
"response = b*answer",
5767
"response = q where q = a*b",
5868
"response = q+p where q = a*b; p = b*c",

0 commit comments

Comments
 (0)