Skip to content

Commit a07755d

Browse files
committed
Switched to in-built functionality of Sympy
1 parent e0ae030 commit a07755d

File tree

1 file changed

+5
-160
lines changed

1 file changed

+5
-160
lines changed

app/utility/expression_utilities.py

Lines changed: 5 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,6 @@
2424
from sympy.printing.latex import LatexPrinter
2525
from sympy import Basic, Symbol, Equality, Function
2626

27-
from sympy import factorial as _sympy_factorial
28-
from sympy.functions.combinatorial.factorials import factorial2 as _sympy_factorial2
29-
30-
3127
import re
3228
from typing import Dict, List, TypedDict
3329

@@ -665,149 +661,6 @@ def preprocess_expression(name, expr, parameters):
665661
success = False
666662
return success, expr, abs_feedback
667663

668-
def convert_double_bang_factorial(s: str) -> str:
669-
"""
670-
Convert double postfix factorial (e.g., n!!, (x+1)!!, 3!!) to function form: factorial2(n), etc.
671-
Safeguards:
672-
- Does NOT treat '!=' specially (since we target '!!').
673-
- Requires two consecutive '!' characters (no whitespace in between).
674-
- Handles balanced parenthesis operands (e.g., '(... )!!').
675-
- Handles identifiers and numeric literals.
676-
"""
677-
n = len(s)
678-
i = 0
679-
last = 0
680-
out = []
681-
682-
while i < n:
683-
ch = s[i]
684-
if ch == '!' and (i + 1) < n and s[i + 1] == '!':
685-
# Look left to find the operand (skip whitespace)
686-
j = i - 1
687-
while j >= 0 and s[j].isspace():
688-
j -= 1
689-
if j < 0:
690-
# Nothing to the left; keep as-is
691-
i += 1
692-
continue
693-
694-
# Case 1: operand ends with ')': parenthesized group
695-
if s[j] == ')':
696-
depth = 1
697-
k = j - 1
698-
while k >= 0 and depth > 0:
699-
if s[k] == ')':
700-
depth += 1
701-
elif s[k] == '(':
702-
depth -= 1
703-
k -= 1
704-
if depth == 0:
705-
L = k + 1 # index of '('
706-
R = j # index of ')'
707-
out.append(s[last:L])
708-
out.append('factorial2(')
709-
out.append(s[L:R+1])
710-
out.append(')')
711-
last = i + 2 # consume both '!'
712-
i += 2
713-
continue
714-
else:
715-
# Unbalanced parentheses; leave as-is
716-
i += 1
717-
continue
718-
719-
# Case 2: operand is an identifier/number ending at j
720-
k = j
721-
while k >= 0 and (s[k].isalnum() or s[k] in ('_', '.')):
722-
k -= 1
723-
L = k + 1
724-
if L <= j:
725-
out.append(s[last:L])
726-
out.append('factorial2(')
727-
out.append(s[L:j+1])
728-
out.append(')')
729-
last = i + 2
730-
i += 2
731-
continue
732-
# If we get here, no valid operand token; fall through and keep scanning.
733-
734-
i += 1
735-
736-
out.append(s[last:])
737-
return ''.join(out)
738-
739-
def convert_bang_factorial(s: str) -> str:
740-
"""
741-
Convert single postfix factorial (e.g., n!, (x+1)!, 3!) to function form: factorial(n), etc.
742-
Safeguards:
743-
- Does NOT convert '!='.
744-
- Does NOT convert '!!' (handled by convert_double_bang_factorial).
745-
"""
746-
n = len(s)
747-
i = 0
748-
last = 0
749-
out = []
750-
751-
while i < n:
752-
ch = s[i]
753-
if ch == '!':
754-
# Skip '!=' and '!!' (the latter handled in a separate pass)
755-
nxt = s[i+1] if i + 1 < n else ''
756-
if nxt in ('=', '!'):
757-
i += 1
758-
continue
759-
760-
# Move left to find the operand (skip whitespace)
761-
j = i - 1
762-
while j >= 0 and s[j].isspace():
763-
j -= 1
764-
if j < 0:
765-
i += 1
766-
continue
767-
768-
# Parenthesized operand
769-
if s[j] == ')':
770-
depth = 1
771-
k = j - 1
772-
while k >= 0 and depth > 0:
773-
if s[k] == ')':
774-
depth += 1
775-
elif s[k] == '(':
776-
depth -= 1
777-
k -= 1
778-
if depth == 0:
779-
L = k + 1
780-
R = j
781-
out.append(s[last:L])
782-
out.append('factorial(')
783-
out.append(s[L:R+1])
784-
out.append(')')
785-
last = i + 1
786-
i += 1
787-
continue
788-
else:
789-
i += 1
790-
continue
791-
792-
# Identifier/number operand
793-
k = j
794-
while k >= 0 and (s[k].isalnum() or s[k] in ('_', '.')):
795-
k -= 1
796-
L = k + 1
797-
if L <= j:
798-
out.append(s[last:L])
799-
out.append('factorial(')
800-
out.append(s[L:j+1])
801-
out.append(')')
802-
last = i + 1
803-
i += 1
804-
continue
805-
806-
i += 1
807-
808-
out.append(s[last:])
809-
return ''.join(out)
810-
811664

812665
def parse_expression(expr_string, parsing_params):
813666
'''
@@ -828,36 +681,27 @@ def parse_expression(expr_string, parsing_params):
828681
separate_unsplittable_symbols = [(x, " "+x) for x in unsplittable_symbols]
829682
substitutions = separate_unsplittable_symbols
830683

831-
symbol_dict = dict(symbol_dict)
832-
symbol_dict.setdefault("factorial", _sympy_factorial)
833-
symbol_dict.setdefault("factorial2", _sympy_factorial2)
834-
835-
if 'factorial' not in unsplittable_symbols or 'factorial2' not in unsplittable_symbols:
836-
unsplittable_symbols = tuple(
837-
list(unsplittable_symbols)
838-
+ [s for s in ('factorial', 'factorial2') if s not in unsplittable_symbols]
839-
)
840-
841684
parsed_expr_set = set()
842685
for expr in expr_set:
843686
expr = preprocess_according_to_chosen_convention(expr, parsing_params)
844687
substitutions = list(set(substitutions))
845688
substitutions.sort(key=substitutions_sort_key)
846689
if parsing_params["elementary_functions"] is True:
847690
substitutions += protect_elementary_functions_substitutions(expr)
848-
expr = convert_double_bang_factorial(expr)
849-
expr = convert_bang_factorial(expr)
691+
850692
substitutions = list(set(substitutions))
851693
substitutions.sort(key=substitutions_sort_key)
852694
expr = substitute(expr, substitutions)
853695
expr = " ".join(expr.split())
696+
854697
can_split = lambda x: False if x in unsplittable_symbols else _token_splittable(x)
855698
if strict_syntax is True:
856699
transformations = parser_transformations[0:4]+extra_transformations
857700
else:
858701
transformations = parser_transformations[0:5, 6]+extra_transformations+(split_symbols_custom(can_split),)+parser_transformations[8, 9]
859702
if parsing_params.get("rationalise", False):
860703
transformations += parser_transformations[11]
704+
861705
if "=" in expr:
862706
expr_parts = expr.split("=")
863707
lhs = parse_expr(expr_parts[0], transformations=transformations, local_dict=symbol_dict)
@@ -869,11 +713,12 @@ def parse_expression(expr_string, parsing_params):
869713
parsed_expr = parsed_expr.simplify()
870714
else:
871715
parsed_expr = parse_expr(expr, transformations=transformations, local_dict=symbol_dict, evaluate=False)
716+
872717
if not isinstance(parsed_expr, Basic):
873718
raise ValueError(f"Failed to parse Sympy expression `{expr}`")
874719
parsed_expr_set.add(parsed_expr)
875720

876721
if len(expr_set) == 1:
877722
return parsed_expr_set.pop()
878723
else:
879-
return parsed_expr_set
724+
return parsed_expr_set

0 commit comments

Comments
 (0)