@@ -35,6 +35,8 @@ def check_criterion(criterion, parameters_dict, generate_feedback=True):
3535 parsing_params .update ({"simplify" : False })
3636 if label in {"EQUALITY" , "WRITTEN_AS" }:
3737 result = check_equality (criterion , parameters_dict )
38+ if label == "ORDER" :
39+ result = check_order (criterion , parameters_dict )
3840 elif label == "WHERE" :
3941 crit = criterion .children [0 ]
4042 subs = criterion .children [1 ]
@@ -53,7 +55,7 @@ def check_criterion(criterion, parameters_dict, generate_feedback=True):
5355 return result
5456
5557
56- def check_equality (criterion , parameters_dict , local_substitutions = []):
58+ def create_expressions_for_comparison (criterion , parameters_dict , local_substitutions = []):
5759 parsing_params = deepcopy (parameters_dict ["parsing_parameters" ])
5860 reserved_expressions = list (parameters_dict ["reserved_expressions" ].items ())
5961 parsing_params .update (
@@ -65,7 +67,6 @@ def check_equality(criterion, parameters_dict, local_substitutions=[]):
6567 )
6668 lhs = criterion .children [0 ].content_string ()
6769 rhs = criterion .children [1 ].content_string ()
68-
6970 lhs_expr = parse_expression (lhs , parsing_params ).subs (local_substitutions ).subs (reserved_expressions ).subs (local_substitutions )
7071 rhs_expr = parse_expression (rhs , parsing_params ).subs (local_substitutions ).subs (reserved_expressions ).subs (local_substitutions )
7172 if parsing_params .get ("complexNumbers" , False ):
@@ -74,15 +75,31 @@ def check_equality(criterion, parameters_dict, local_substitutions=[]):
7475 if (im (lhs_expr ) != 0 ) or (im (lhs_expr ) != 0 ):
7576 lhs_expr = real_part (simplified_lhs_expr ) + I * im (simplified_lhs_expr )
7677 rhs_expr = real_part (simplified_rhs_expr ) + I * im (simplified_rhs_expr )
77- expression = (lhs_expr - rhs_expr )
78- result = bool (expression .cancel ().simplify ().simplify () == 0 )
78+ return lhs_expr , rhs_expr
79+
80+
81+ def do_comparison (comparison_symbol , expression ):
82+ comparisons = {
83+ "=" : lambda expr : bool (expression .cancel ().simplify ().simplify () == 0 ),
84+ ">" : lambda expr : bool (expression .cancel ().simplify ().simplify () > 0 ),
85+ ">=" : lambda expr : bool (expression .cancel ().simplify ().simplify () >= 0 ),
86+ "<" : lambda expr : bool (expression .cancel ().simplify ().simplify () < 0 ),
87+ "<=" : lambda expr : bool (expression .cancel ().simplify ().simplify () <= 0 ),
88+ }
89+ comparison = comparisons [comparison_symbol .strip ()]
90+ result = comparison (expression )
91+ return result
92+
93+
94+ def check_equality (criterion , parameters_dict , local_substitutions = []):
95+ lhs_expr , rhs_expr = create_expressions_for_comparison (criterion , parameters_dict , local_substitutions )
96+ result = do_comparison (criterion .content , lhs_expr - rhs_expr )
7997
8098 # TODO: Make numerical comparison its own context
8199 if result is False :
82100 error_below_rtol = None
83101 error_below_atol = None
84102 if parameters_dict .get ("numerical" , False ) or float (parameters_dict .get ("rtol" , 0 )) > 0 or float (parameters_dict .get ("atol" , 0 )) > 0 :
85-
86103 # REMARK: 'pi' should be a reserved symbol but it is sometimes not treated as one, possibly because of input symbols.
87104 # The two lines below this comments fixes the issue but a more robust solution should be found for cases where there
88105 # are other reserved symbols.
@@ -92,12 +109,10 @@ def replace_pi(expr):
92109 if str (s ) == 'pi' :
93110 pi_symbol = s
94111 return expr .subs (pi_symbol , float (pi ))
95-
96112 # NOTE: This code assumes that the left hand side is the response and the right hand side is the answer
97113 # Separates LHS and RHS, parses and evaluates them
98114 res = N (replace_pi (lhs_expr ))
99115 ans = N (replace_pi (rhs_expr ))
100-
101116 if float (parameters_dict .get ("atol" , 0 )) > 0 :
102117 try :
103118 absolute_error = abs (float (ans - res ))
@@ -122,6 +137,12 @@ def replace_pi(expr):
122137 return result
123138
124139
140+ def check_order (criterion , parameters_dict , local_substitutions = []):
141+ lhs_expr , rhs_expr = create_expressions_for_comparison (criterion , parameters_dict , local_substitutions )
142+ result = do_comparison (criterion .content , lhs_expr - rhs_expr )
143+ return result
144+
145+
125146def find_coords_for_node_type (expression , node_type ):
126147 stack = [(expression , tuple ())]
127148 node_coords = []
@@ -683,16 +704,18 @@ def get_candidates(unused_input):
683704
684705
685706def criterion_eval_node (criterion , parameters_dict , generate_feedback = True ):
707+ feedback_string_generator_inputs = {'criterion' : criterion }
708+
686709 def evaluation_node_internal (unused_input ):
687710 result = check_criterion (criterion , parameters_dict , generate_feedback )
688711 label = criterion .content_string ()
689712 if result :
690713 return {
691- label + "_TRUE" : None
714+ label + "_TRUE" : feedback_string_generator_inputs
692715 }
693716 else :
694717 return {
695- label + "_FALSE" : None
718+ label + "_FALSE" : feedback_string_generator_inputs
696719 }
697720 label = criterion .content_string ()
698721 graph = CriteriaGraph (label )
0 commit comments