@@ -101,46 +101,60 @@ def do_comparison(comparison_symbol, expression):
101101
def check_equality(criterion, parameters_dict, local_substitutions=None):
    """Decide whether the two sides of ``criterion`` are equal.

    The left- and right-hand expressions are built with
    ``create_expressions_for_comparison`` and their difference is handed to
    ``do_comparison`` together with ``criterion.content``. If that does not
    succeed, the comparison is retried with simplified operands and, when
    numerical comparison is requested, with absolute/relative tolerances.

    Parameters:
        criterion: criterion object; its ``content`` attribute is forwarded
            to ``do_comparison``.
        parameters_dict: evaluation parameters. The keys read here are
            ``"numerical"``, ``"atol"`` and ``"rtol"``.
        local_substitutions: optional list of substitutions forwarded to
            ``create_expressions_for_comparison``. Defaults to an empty list.

    Returns:
        The outcome of the comparison. ``False`` is returned immediately when
        exactly one of the two sides is an ``Equality``.
    """
    # A fresh list per call avoids sharing one mutable default between calls.
    if local_substitutions is None:
        local_substitutions = []

    lhs_expr, rhs_expr = create_expressions_for_comparison(criterion, parameters_dict, local_substitutions)

    # An equation can never match a plain expression (and vice versa), so bail
    # out before attempting any arithmetic on the mismatched pair.
    if isinstance(lhs_expr, Equality) != isinstance(rhs_expr, Equality):
        return False

    result = do_comparison(criterion.content, lhs_expr - rhs_expr)

    # There are some types of expression, e.g. those containing hyperbolic
    # trigonometric functions, that can behave unpredictably when
    # simplification is applied. For that reason several different
    # combinations of simplifications are tried here in order to reduce the
    # likelihood of false negatives. Each simplified form is computed at most
    # once and only when it is actually needed.
    if result is False:
        rhs_simplified = rhs_expr.simplify()
        result = do_comparison(criterion.content, lhs_expr - rhs_simplified)
        if result is False:
            lhs_simplified = lhs_expr.simplify()
            result = do_comparison(criterion.content, lhs_simplified - rhs_expr)
            if result is False:
                result = do_comparison(criterion.content, lhs_simplified - rhs_simplified)

    # TODO: Make numerical comparison its own context
    if result is False:
        error_below_rtol = None
        error_below_atol = None
        # Parse the tolerances once instead of re-reading them at every use.
        atol = float(parameters_dict.get("atol", 0))
        rtol = float(parameters_dict.get("rtol", 0))
        if parameters_dict.get("numerical", False) or rtol > 0 or atol > 0:
            # REMARK: 'pi' should be a reserved symbol but it is sometimes not
            # treated as one, possibly because of input symbols. The helper
            # below fixes the issue, but a more robust solution should be
            # found for cases where there are other reserved symbols.
            def replace_pi(expr):
                pi_symbol = pi
                for s in expr.free_symbols:
                    if str(s) == 'pi':
                        pi_symbol = s
                return expr.subs(pi_symbol, float(pi))

            # NOTE: This code assumes that the left hand side is the response
            # and the right hand side is the answer.
            res = N(replace_pi(lhs_expr))
            ans = N(replace_pi(rhs_expr))

            if atol > 0:
                try:
                    absolute_error = abs(float(ans - res))
                    error_below_atol = bool(absolute_error < atol)
                except TypeError:
                    # The difference could not be converted to a float.
                    error_below_atol = None
            else:
                # No absolute tolerance requested: treat it as satisfied.
                error_below_atol = True

            if rtol > 0:
                try:
                    relative_error = abs(float((ans - res) / ans))
                    error_below_rtol = bool(relative_error < rtol)
                except TypeError:
                    # The ratio could not be converted to a float.
                    error_below_rtol = None
            else:
                # No relative tolerance requested: treat it as satisfied.
                error_below_rtol = True

            if error_below_atol is None or error_below_rtol is None:
                result = False
            elif error_below_atol is True and error_below_rtol is True:
                result = True

    return result
146160
@@ -252,7 +266,12 @@ def set_equivalence(unused_input):
252266 result = None
253267 for j , answer in enumerate (answer_list ):
254268 current_pair = [("response" , response ), ("answer" , answer )]
255- result = check_equality (criterion , parameters_dict , local_substitutions = current_pair )
269+ if isinstance (response , Equality ) and not isinstance (answer , Equality ):
270+ result = False
271+ elif not isinstance (response , Equality ) and isinstance (answer , Equality ):
272+ result = False
273+ else :
274+ result = check_equality (criterion , parameters_dict , local_substitutions = current_pair )
256275 if result is True :
257276 matches ["responses" ][i ] = True
258277 matches ["answers" ][j ] = True
@@ -397,6 +416,14 @@ def same_symbols(unused_input):
397416 details = "Checks if " + str (lhs )+ " is equivalent to " + str (rhs )+ "." ,
398417 evaluate = equality_equivalence
399418 )
419+ graph .attach (
420+ label ,
421+ label + "_UNKNOWN" ,
422+ summary = "Cannot determine if " + str (lhs )+ " is equivalent to " + str (rhs ),
423+ details = "Cannot determine if " + str (lhs )+ " is equivalent to " + str (rhs )+ "." ,
424+ feedback_string_generator = symbolic_feedback_string_generators ["INTERNAL" ]("EQUALITY_EQUIVALENCE_UNKNOWN" )
425+ )
426+ graph .attach (label + "_UNKNOWN" , END .label )
400427 graph .attach (
401428 label ,
402429 label + "_TRUE" ,
@@ -474,6 +501,14 @@ def same_symbols(unused_input):
474501 feedback_string_generator = symbolic_feedback_string_generators ["response=answer" ]("FALSE" )
475502 )
476503 graph .attach (label + "_FALSE" , END .label )
504+ graph .attach (
505+ label ,
506+ label + "_UNKNOWN" ,
507+ summary = "Cannot detrmine if " + str (lhs )+ "=" + str (rhs ),
508+ details = "Cannot detrmine if " + str (lhs )+ " is equal to " + str (rhs )+ "." ,
509+ feedback_string_generator = symbolic_feedback_string_generators ["response=answer" ]("UNKNOWN" )
510+ )
511+ graph .attach (label + "_UNKNOWN" , END .label )
477512 return graph
478513
479514
0 commit comments