Skip to content

Commit 1e00c6e

Browse files
committed
Fixed issues with plus_minus and other failing tests
1 parent 7cce2fa commit 1e00c6e

File tree

1 file changed

+10
-24
lines changed

1 file changed

+10
-24
lines changed

app/utility/expression_utilities.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -825,8 +825,9 @@ def parse_expression(expr_string, parsing_params):
825825
extra_transformations = parsing_params.get("extra_transformations", ())
826826
unsplittable_symbols = parsing_params.get("unsplittable_symbols", ())
827827
symbol_dict = parsing_params.get("symbol_dict", {})
828+
separate_unsplittable_symbols = [(x, " "+x) for x in unsplittable_symbols]
829+
substitutions = separate_unsplittable_symbols
828830

829-
# --- Ensure factorial and factorial2 are known and not split ---
830831
symbol_dict = dict(symbol_dict)
831832
symbol_dict.setdefault("factorial", _sympy_factorial)
832833
symbol_dict.setdefault("factorial2", _sympy_factorial2)
@@ -837,40 +838,26 @@ def parse_expression(expr_string, parsing_params):
837838
+ [s for s in ('factorial', 'factorial2') if s not in unsplittable_symbols]
838839
)
839840

840-
separate_unsplittable_symbols = [(x, " "+x) for x in unsplittable_symbols]
841-
substitutions = separate_unsplittable_symbols
842-
843841
parsed_expr_set = set()
844842
for expr in expr_set:
845843
expr = preprocess_according_to_chosen_convention(expr, parsing_params)
846-
847844
substitutions = list(set(substitutions))
848845
substitutions.sort(key=substitutions_sort_key)
849-
if parsing_params.get("elementary_functions") is True:
846+
if parsing_params["elementary_functions"] is True:
850847
substitutions += protect_elementary_functions_substitutions(expr)
851-
852-
expr = convert_double_bang_factorial(expr) # n!! -> factorial2(n)
853-
expr = convert_bang_factorial(expr) # n! -> factorial(n)
854-
848+
expr = convert_double_bang_factorial(expr)
849+
expr = convert_bang_factorial(expr)
855850
substitutions = list(set(substitutions))
856851
substitutions.sort(key=substitutions_sort_key)
857852
expr = substitute(expr, substitutions)
858853
expr = " ".join(expr.split())
859-
860854
can_split = lambda x: False if x in unsplittable_symbols else _token_splittable(x)
861855
if strict_syntax is True:
862856
transformations = parser_transformations[0:4]+extra_transformations
863857
else:
864-
transformations = (
865-
parser_transformations[0:5, 6] # keep your existing set-up
866-
+ extra_transformations
867-
+ (split_symbols_custom(can_split),)
868-
+ parser_transformations[8, 9]
869-
)
870-
858+
transformations = parser_transformations[0:5, 6]+extra_transformations+(split_symbols_custom(can_split),)+parser_transformations[8, 9]
871859
if parsing_params.get("rationalise", False):
872860
transformations += parser_transformations[11]
873-
874861
if "=" in expr:
875862
expr_parts = expr.split("=")
876863
lhs = parse_expr(expr_parts[0], transformations=transformations, local_dict=symbol_dict)
@@ -882,12 +869,11 @@ def parse_expression(expr_string, parsing_params):
882869
parsed_expr = parsed_expr.simplify()
883870
else:
884871
parsed_expr = parse_expr(expr, transformations=transformations, local_dict=symbol_dict, evaluate=False)
885-
886872
if not isinstance(parsed_expr, Basic):
887873
raise ValueError(f"Failed to parse Sympy expression `{expr}`")
888874
parsed_expr_set.add(parsed_expr)
889875

890-
if len(expr_set) == 1:
891-
return parsed_expr_set.pop()
892-
else:
893-
return parsed_expr_set
876+
if len(expr_set) == 1:
877+
return parsed_expr_set.pop()
878+
else:
879+
return parsed_expr_set

0 commit comments

Comments
 (0)