2424from sympy .printing .latex import LatexPrinter
2525from sympy import Basic , Symbol , Equality , Function
2626
27+ from sympy import factorial as _sympy_factorial
28+ from sympy .functions .combinatorial .factorials import factorial2 as _sympy_factorial2
29+
30+
2731import re
2832from typing import Dict , List , TypedDict
2933
@@ -661,6 +665,149 @@ def preprocess_expression(name, expr, parameters):
661665 success = False
662666 return success , expr , abs_feedback
663667
668+ def convert_double_bang_factorial (s : str ) -> str :
669+ """
670+ Convert double postfix factorial (e.g., n!!, (x+1)!!, 3!!) to function form: factorial2(n), etc.
671+ Safeguards:
672+ - Does NOT treat '!=' specially (since we target '!!').
673+ - Requires two consecutive '!' characters (no whitespace in between).
674+ - Handles balanced parenthesis operands (e.g., '(... )!!').
675+ - Handles identifiers and numeric literals.
676+ """
677+ n = len (s )
678+ i = 0
679+ last = 0
680+ out = []
681+
682+ while i < n :
683+ ch = s [i ]
684+ if ch == '!' and (i + 1 ) < n and s [i + 1 ] == '!' :
685+ # Look left to find the operand (skip whitespace)
686+ j = i - 1
687+ while j >= 0 and s [j ].isspace ():
688+ j -= 1
689+ if j < 0 :
690+ # Nothing to the left; keep as-is
691+ i += 1
692+ continue
693+
694+ # Case 1: operand ends with ')': parenthesized group
695+ if s [j ] == ')' :
696+ depth = 1
697+ k = j - 1
698+ while k >= 0 and depth > 0 :
699+ if s [k ] == ')' :
700+ depth += 1
701+ elif s [k ] == '(' :
702+ depth -= 1
703+ k -= 1
704+ if depth == 0 :
705+ L = k + 1 # index of '('
706+ R = j # index of ')'
707+ out .append (s [last :L ])
708+ out .append ('factorial2(' )
709+ out .append (s [L :R + 1 ])
710+ out .append (')' )
711+ last = i + 2 # consume both '!'
712+ i += 2
713+ continue
714+ else :
715+ # Unbalanced parentheses; leave as-is
716+ i += 1
717+ continue
718+
719+ # Case 2: operand is an identifier/number ending at j
720+ k = j
721+ while k >= 0 and (s [k ].isalnum () or s [k ] in ('_' , '.' )):
722+ k -= 1
723+ L = k + 1
724+ if L <= j :
725+ out .append (s [last :L ])
726+ out .append ('factorial2(' )
727+ out .append (s [L :j + 1 ])
728+ out .append (')' )
729+ last = i + 2
730+ i += 2
731+ continue
732+ # If we get here, no valid operand token; fall through and keep scanning.
733+
734+ i += 1
735+
736+ out .append (s [last :])
737+ return '' .join (out )
738+
739+ def convert_bang_factorial (s : str ) -> str :
740+ """
741+ Convert single postfix factorial (e.g., n!, (x+1)!, 3!) to function form: factorial(n), etc.
742+ Safeguards:
743+ - Does NOT convert '!='.
744+ - Does NOT convert '!!' (handled by convert_double_bang_factorial).
745+ """
746+ n = len (s )
747+ i = 0
748+ last = 0
749+ out = []
750+
751+ while i < n :
752+ ch = s [i ]
753+ if ch == '!' :
754+ # Skip '!=' and '!!' (the latter handled in a separate pass)
755+ nxt = s [i + 1 ] if i + 1 < n else ''
756+ if nxt in ('=' , '!' ):
757+ i += 1
758+ continue
759+
760+ # Move left to find the operand (skip whitespace)
761+ j = i - 1
762+ while j >= 0 and s [j ].isspace ():
763+ j -= 1
764+ if j < 0 :
765+ i += 1
766+ continue
767+
768+ # Parenthesized operand
769+ if s [j ] == ')' :
770+ depth = 1
771+ k = j - 1
772+ while k >= 0 and depth > 0 :
773+ if s [k ] == ')' :
774+ depth += 1
775+ elif s [k ] == '(' :
776+ depth -= 1
777+ k -= 1
778+ if depth == 0 :
779+ L = k + 1
780+ R = j
781+ out .append (s [last :L ])
782+ out .append ('factorial(' )
783+ out .append (s [L :R + 1 ])
784+ out .append (')' )
785+ last = i + 1
786+ i += 1
787+ continue
788+ else :
789+ i += 1
790+ continue
791+
792+ # Identifier/number operand
793+ k = j
794+ while k >= 0 and (s [k ].isalnum () or s [k ] in ('_' , '.' )):
795+ k -= 1
796+ L = k + 1
797+ if L <= j :
798+ out .append (s [last :L ])
799+ out .append ('factorial(' )
800+ out .append (s [L :j + 1 ])
801+ out .append (')' )
802+ last = i + 1
803+ i += 1
804+ continue
805+
806+ i += 1
807+
808+ out .append (s [last :])
809+ return '' .join (out )
810+
664811
665812def parse_expression (expr_string , parsing_params ):
666813 '''
@@ -678,27 +825,52 @@ def parse_expression(expr_string, parsing_params):
678825 extra_transformations = parsing_params .get ("extra_transformations" , ())
679826 unsplittable_symbols = parsing_params .get ("unsplittable_symbols" , ())
680827 symbol_dict = parsing_params .get ("symbol_dict" , {})
828+
829+ # --- Ensure factorial and factorial2 are known and not split ---
830+ symbol_dict = dict (symbol_dict )
831+ symbol_dict .setdefault ("factorial" , _sympy_factorial )
832+ symbol_dict .setdefault ("factorial2" , _sympy_factorial2 )
833+
834+ if 'factorial' not in unsplittable_symbols or 'factorial2' not in unsplittable_symbols :
835+ unsplittable_symbols = tuple (
836+ list (unsplittable_symbols )
837+ + [s for s in ('factorial' , 'factorial2' ) if s not in unsplittable_symbols ]
838+ )
839+
681840 separate_unsplittable_symbols = [(x , " " + x ) for x in unsplittable_symbols ]
682841 substitutions = separate_unsplittable_symbols
683842
684843 parsed_expr_set = set ()
685844 for expr in expr_set :
686845 expr = preprocess_according_to_chosen_convention (expr , parsing_params )
846+
687847 substitutions = list (set (substitutions ))
688848 substitutions .sort (key = substitutions_sort_key )
689- if parsing_params [ "elementary_functions" ] is True :
849+ if parsing_params . get ( "elementary_functions" ) is True :
690850 substitutions += protect_elementary_functions_substitutions (expr )
851+
852+ expr = convert_double_bang_factorial (expr ) # n!! -> factorial2(n)
853+ expr = convert_bang_factorial (expr ) # n! -> factorial(n)
854+
691855 substitutions = list (set (substitutions ))
692856 substitutions .sort (key = substitutions_sort_key )
693857 expr = substitute (expr , substitutions )
694858 expr = " " .join (expr .split ())
859+
695860 can_split = lambda x : False if x in unsplittable_symbols else _token_splittable (x )
696861 if strict_syntax is True :
697862 transformations = parser_transformations [0 :4 ]+ extra_transformations
698863 else :
699- transformations = parser_transformations [0 :5 , 6 ]+ extra_transformations + (split_symbols_custom (can_split ),)+ parser_transformations [8 , 9 ]
864+ transformations = (
865+ parser_transformations [0 :5 , 6 ] # keep your existing set-up
866+ + extra_transformations
867+ + (split_symbols_custom (can_split ),)
868+ + parser_transformations [8 , 9 ]
869+ )
870+
700871 if parsing_params .get ("rationalise" , False ):
701872 transformations += parser_transformations [11 ]
873+
702874 if "=" in expr :
703875 expr_parts = expr .split ("=" )
704876 lhs = parse_expr (expr_parts [0 ], transformations = transformations , local_dict = symbol_dict )
@@ -710,11 +882,12 @@ def parse_expression(expr_string, parsing_params):
710882 parsed_expr = parsed_expr .simplify ()
711883 else :
712884 parsed_expr = parse_expr (expr , transformations = transformations , local_dict = symbol_dict , evaluate = False )
885+
713886 if not isinstance (parsed_expr , Basic ):
714887 raise ValueError (f"Failed to parse Sympy expression `{ expr } `" )
715888 parsed_expr_set .add (parsed_expr )
716889
717- if len (expr_set ) == 1 :
718- return parsed_expr_set .pop ()
719- else :
720- return parsed_expr_set
890+ if len (expr_set ) == 1 :
891+ return parsed_expr_set .pop ()
892+ else :
893+ return parsed_expr_set
0 commit comments