Skip to content

Commit 102f73b

Browse files
committed
Added support for ! and !! factorials
1 parent 9570707 commit 102f73b

File tree

2 files changed

+209
-10
lines changed

2 files changed

+209
-10
lines changed

app/tests/symbolic_evaluation_tests.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1861,15 +1861,41 @@ def test_sum_in_answer(self, response, answer, value):
18611861
result = evaluation_function(response, answer, params)
18621862
assert result["is_correct"] is value
18631863

1864-
def test_exclamation_mark_for_factorial(self):
1865-
response = "3!"
1866-
answer = "factorial(3)"
1864+
@pytest.mark.parametrize(
1865+
"response, answer, value",
1866+
[
1867+
("3!", "factorial(3)", True),
1868+
("(n+1)!", "factorial(n+1)", True),
1869+
("n!", "factorial(n)", True),
1870+
("a!=b", "factorial(3)", False),
1871+
("2*n!", "2*factorial(n)", True),
1872+
]
1873+
)
1874+
def test_exclamation_mark_for_factorial(self, response, answer, value):
18671875
params = {
18681876
"strict_syntax": False,
18691877
"elementary_functions": True,
18701878
}
18711879
result = evaluation_function(response, answer, params)
1872-
assert result["is_correct"] is True
1880+
assert result["is_correct"] is value
1881+
1882+
@pytest.mark.parametrize(
1883+
"response, answer, value",
1884+
[
1885+
("3!!", "factorial2(3)", True),
1886+
("(n+1)!!", "factorial2(n+1)", True),
1887+
("n!!", "factorial2(n)", True),
1888+
("a!=b", "factorial2(3)", False),
1889+
("2*n!!", "2*factorial2(n)", True),
1890+
]
1891+
)
1892+
def test_double_exclamation_mark_for_factorial(self, response, answer, value):
1893+
params = {
1894+
"strict_syntax": False,
1895+
"elementary_functions": True,
1896+
}
1897+
result = evaluation_function(response, answer, params)
1898+
assert result["is_correct"] is value
18731899

18741900
def test_alternatives_to_input_symbols_takes_priority_over_elementary_function_alternatives(self):
18751901
answer = "Ef*exp(x)"

app/utility/expression_utilities.py

Lines changed: 179 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
from sympy.printing.latex import LatexPrinter
2525
from sympy import Basic, Symbol, Equality, Function
2626

27+
from sympy import factorial as _sympy_factorial
28+
from sympy.functions.combinatorial.factorials import factorial2 as _sympy_factorial2
29+
30+
2731
import re
2832
from typing import Dict, List, TypedDict
2933

@@ -661,6 +665,149 @@ def preprocess_expression(name, expr, parameters):
661665
success = False
662666
return success, expr, abs_feedback
663667

668+
def convert_double_bang_factorial(s: str) -> str:
669+
"""
670+
Convert double postfix factorial (e.g., n!!, (x+1)!!, 3!!) to function form: factorial2(n), etc.
671+
Safeguards:
672+
- Does NOT treat '!=' specially (since we target '!!').
673+
- Requires two consecutive '!' characters (no whitespace in between).
674+
- Handles balanced parenthesis operands (e.g., '(... )!!').
675+
- Handles identifiers and numeric literals.
676+
"""
677+
n = len(s)
678+
i = 0
679+
last = 0
680+
out = []
681+
682+
while i < n:
683+
ch = s[i]
684+
if ch == '!' and (i + 1) < n and s[i + 1] == '!':
685+
# Look left to find the operand (skip whitespace)
686+
j = i - 1
687+
while j >= 0 and s[j].isspace():
688+
j -= 1
689+
if j < 0:
690+
# Nothing to the left; keep as-is
691+
i += 1
692+
continue
693+
694+
# Case 1: operand ends with ')': parenthesized group
695+
if s[j] == ')':
696+
depth = 1
697+
k = j - 1
698+
while k >= 0 and depth > 0:
699+
if s[k] == ')':
700+
depth += 1
701+
elif s[k] == '(':
702+
depth -= 1
703+
k -= 1
704+
if depth == 0:
705+
L = k + 1 # index of '('
706+
R = j # index of ')'
707+
out.append(s[last:L])
708+
out.append('factorial2(')
709+
out.append(s[L:R+1])
710+
out.append(')')
711+
last = i + 2 # consume both '!'
712+
i += 2
713+
continue
714+
else:
715+
# Unbalanced parentheses; leave as-is
716+
i += 1
717+
continue
718+
719+
# Case 2: operand is an identifier/number ending at j
720+
k = j
721+
while k >= 0 and (s[k].isalnum() or s[k] in ('_', '.')):
722+
k -= 1
723+
L = k + 1
724+
if L <= j:
725+
out.append(s[last:L])
726+
out.append('factorial2(')
727+
out.append(s[L:j+1])
728+
out.append(')')
729+
last = i + 2
730+
i += 2
731+
continue
732+
# If we get here, no valid operand token; fall through and keep scanning.
733+
734+
i += 1
735+
736+
out.append(s[last:])
737+
return ''.join(out)
738+
739+
def convert_bang_factorial(s: str) -> str:
740+
"""
741+
Convert single postfix factorial (e.g., n!, (x+1)!, 3!) to function form: factorial(n), etc.
742+
Safeguards:
743+
- Does NOT convert '!='.
744+
- Does NOT convert '!!' (handled by convert_double_bang_factorial).
745+
"""
746+
n = len(s)
747+
i = 0
748+
last = 0
749+
out = []
750+
751+
while i < n:
752+
ch = s[i]
753+
if ch == '!':
754+
# Skip '!=' and '!!' (the latter handled in a separate pass)
755+
nxt = s[i+1] if i + 1 < n else ''
756+
if nxt in ('=', '!'):
757+
i += 1
758+
continue
759+
760+
# Move left to find the operand (skip whitespace)
761+
j = i - 1
762+
while j >= 0 and s[j].isspace():
763+
j -= 1
764+
if j < 0:
765+
i += 1
766+
continue
767+
768+
# Parenthesized operand
769+
if s[j] == ')':
770+
depth = 1
771+
k = j - 1
772+
while k >= 0 and depth > 0:
773+
if s[k] == ')':
774+
depth += 1
775+
elif s[k] == '(':
776+
depth -= 1
777+
k -= 1
778+
if depth == 0:
779+
L = k + 1
780+
R = j
781+
out.append(s[last:L])
782+
out.append('factorial(')
783+
out.append(s[L:R+1])
784+
out.append(')')
785+
last = i + 1
786+
i += 1
787+
continue
788+
else:
789+
i += 1
790+
continue
791+
792+
# Identifier/number operand
793+
k = j
794+
while k >= 0 and (s[k].isalnum() or s[k] in ('_', '.')):
795+
k -= 1
796+
L = k + 1
797+
if L <= j:
798+
out.append(s[last:L])
799+
out.append('factorial(')
800+
out.append(s[L:j+1])
801+
out.append(')')
802+
last = i + 1
803+
i += 1
804+
continue
805+
806+
i += 1
807+
808+
out.append(s[last:])
809+
return ''.join(out)
810+
664811

665812
def parse_expression(expr_string, parsing_params):
666813
'''
@@ -678,27 +825,52 @@ def parse_expression(expr_string, parsing_params):
678825
extra_transformations = parsing_params.get("extra_transformations", ())
679826
unsplittable_symbols = parsing_params.get("unsplittable_symbols", ())
680827
symbol_dict = parsing_params.get("symbol_dict", {})
828+
829+
# --- Ensure factorial and factorial2 are known and not split ---
830+
symbol_dict = dict(symbol_dict)
831+
symbol_dict.setdefault("factorial", _sympy_factorial)
832+
symbol_dict.setdefault("factorial2", _sympy_factorial2)
833+
834+
if 'factorial' not in unsplittable_symbols or 'factorial2' not in unsplittable_symbols:
835+
unsplittable_symbols = tuple(
836+
list(unsplittable_symbols)
837+
+ [s for s in ('factorial', 'factorial2') if s not in unsplittable_symbols]
838+
)
839+
681840
separate_unsplittable_symbols = [(x, " "+x) for x in unsplittable_symbols]
682841
substitutions = separate_unsplittable_symbols
683842

684843
parsed_expr_set = set()
685844
for expr in expr_set:
686845
expr = preprocess_according_to_chosen_convention(expr, parsing_params)
846+
687847
substitutions = list(set(substitutions))
688848
substitutions.sort(key=substitutions_sort_key)
689-
if parsing_params["elementary_functions"] is True:
849+
if parsing_params.get("elementary_functions") is True:
690850
substitutions += protect_elementary_functions_substitutions(expr)
851+
852+
expr = convert_double_bang_factorial(expr) # n!! -> factorial2(n)
853+
expr = convert_bang_factorial(expr) # n! -> factorial(n)
854+
691855
substitutions = list(set(substitutions))
692856
substitutions.sort(key=substitutions_sort_key)
693857
expr = substitute(expr, substitutions)
694858
expr = " ".join(expr.split())
859+
695860
can_split = lambda x: False if x in unsplittable_symbols else _token_splittable(x)
696861
if strict_syntax is True:
697862
transformations = parser_transformations[0:4]+extra_transformations
698863
else:
699-
transformations = parser_transformations[0:5, 6]+extra_transformations+(split_symbols_custom(can_split),)+parser_transformations[8, 9]
864+
transformations = (
865+
parser_transformations[0:5, 6] # keep your existing set-up
866+
+ extra_transformations
867+
+ (split_symbols_custom(can_split),)
868+
+ parser_transformations[8, 9]
869+
)
870+
700871
if parsing_params.get("rationalise", False):
701872
transformations += parser_transformations[11]
873+
702874
if "=" in expr:
703875
expr_parts = expr.split("=")
704876
lhs = parse_expr(expr_parts[0], transformations=transformations, local_dict=symbol_dict)
@@ -710,11 +882,12 @@ def parse_expression(expr_string, parsing_params):
710882
parsed_expr = parsed_expr.simplify()
711883
else:
712884
parsed_expr = parse_expr(expr, transformations=transformations, local_dict=symbol_dict, evaluate=False)
885+
713886
if not isinstance(parsed_expr, Basic):
714887
raise ValueError(f"Failed to parse Sympy expression `{expr}`")
715888
parsed_expr_set.add(parsed_expr)
716889

717-
if len(expr_set) == 1:
718-
return parsed_expr_set.pop()
719-
else:
720-
return parsed_expr_set
890+
if len(expr_set) == 1:
891+
return parsed_expr_set.pop()
892+
else:
893+
return parsed_expr_set

0 commit comments

Comments
 (0)