diff --git a/personal_python_ast_optimizer/parser/config.py b/personal_python_ast_optimizer/parser/config.py index 2160a92..559d12b 100644 --- a/personal_python_ast_optimizer/parser/config.py +++ b/personal_python_ast_optimizer/parser/config.py @@ -3,6 +3,10 @@ from enum import Enum, EnumType from types import EllipsisType +from personal_python_ast_optimizer.python_info import ( + default_functions_safe_to_exclude_in_test_expr, +) + class TypeHintsToSkip(Enum): NONE = 0 @@ -132,18 +136,20 @@ class OptimizationsConfig(_Config): __slots__ = ( "vars_to_fold", "enums_to_fold", + "functions_safe_to_exclude_in_test_expr", "remove_unused_imports", "fold_constants", "assume_this_machine", ) - def __init__( + def __init__( # noqa: PLR0913 self, vars_to_fold: dict[ str, str | bytes | bool | int | float | complex | None | EllipsisType ] | None = None, enums_to_fold: Iterable[EnumType] | None = None, + functions_safe_to_exclude_in_test_expr: set[str] | None = None, fold_constants: bool = True, remove_unused_imports: bool = True, assume_this_machine: bool = False, @@ -156,6 +162,10 @@ def __init__( if enums_to_fold is None else self._format_enums_to_fold_as_dict(enums_to_fold) ) + self.functions_safe_to_exclude_in_test_expr: set[str] = ( + functions_safe_to_exclude_in_test_expr + or default_functions_safe_to_exclude_in_test_expr + ) self.remove_unused_imports: bool = remove_unused_imports self.assume_this_machine: bool = assume_this_machine self.fold_constants: bool = fold_constants diff --git a/personal_python_ast_optimizer/parser/skipper.py b/personal_python_ast_optimizer/parser/skipper.py index af09ee5..07c3b2e 100644 --- a/personal_python_ast_optimizer/parser/skipper.py +++ b/personal_python_ast_optimizer/parser/skipper.py @@ -23,9 +23,11 @@ is_return_none, remove_duplicate_slots, skip_base_classes, - skip_dangling_expressions, skip_decorators, ) +from personal_python_ast_optimizer.python_info import ( + default_functions_safe_to_exclude_in_test_expr, +) class _NodeContext(Enum): @@ -103,11 +105,18 @@ def generic_visit(self, node: ast.AST) -> ast.AST: for value in old_value: if isinstance(value, ast.AST): value = self.visit(value) # noqa: PLW2901 - if value is None: + + if value is None or ( + self.token_types_config.skip_dangling_expressions + and isinstance(value, ast.Expr) + and isinstance(value.value, ast.Constant) + ): continue + if not isinstance(value, ast.AST): new_values.extend(value) continue + new_values.append(value) if ( @@ -154,9 +163,6 @@ def _combine_imports(body: list) -> None: body[:] = new_body def visit_Module(self, node: ast.Module) -> ast.AST: - if self.token_types_config.skip_dangling_expressions: - skip_dangling_expressions(node) - self.generic_visit(node) if self._simplified_named_tuple: @@ -189,9 +195,6 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST | None: if self._use_version_optimization((3, 0)): skip_base_classes(node, ["object"]) - if self.token_types_config.skip_dangling_expressions: - skip_dangling_expressions(node) - skip_base_classes(node, self.tokens_config.classes_to_skip) skip_decorators(node, self.tokens_config.decorators_to_skip) @@ -272,9 +275,6 @@ def _handle_function_node( if self.token_types_config.skip_type_hints: node.returns = None - if self.token_types_config.skip_dangling_expressions: - skip_dangling_expressions(node) - skip_decorators(node, self.tokens_config.decorators_to_skip) if node.body: @@ -298,9 +298,9 @@ def _should_skip_function( def visit_Try(self, node: ast.Try) -> ast.AST | list[ast.stmt] | None: parsed_node = self.generic_visit(node) - if isinstance( - parsed_node, (ast.Try, ast.TryStar) - ) and self._is_useless_try_node(parsed_node): + if isinstance(parsed_node, (ast.Try, ast.TryStar)) and self._body_is_only_pass( + parsed_node.body + ): return parsed_node.finalbody or None return parsed_node @@ -308,16 +308,16 @@ def visit_Try(self, node: ast.Try) -> ast.AST | list[ast.stmt] | None: def visit_TryStar(self, node: ast.TryStar) -> ast.AST | list[ast.stmt] | None: parsed_node = self.generic_visit(node) - if isinstance( - parsed_node, (ast.Try, ast.TryStar) - ) and self._is_useless_try_node(parsed_node): + if isinstance(parsed_node, (ast.Try, ast.TryStar)) and self._body_is_only_pass( + parsed_node.body + ): return parsed_node.finalbody or None return parsed_node @staticmethod - def _is_useless_try_node(node: ast.Try | ast.TryStar) -> bool: - return all(isinstance(n, ast.Pass) for n in node.body) + def _body_is_only_pass(node_body: list[ast.stmt]) -> bool: + return all(isinstance(n, ast.Pass) for n in node_body) def visit_Attribute(self, node: ast.Attribute) -> ast.AST | None: if isinstance(node.value, ast.Name): @@ -497,13 +497,19 @@ def visit_Dict(self, node: ast.Dict) -> ast.AST: def visit_If(self, node: ast.If) -> ast.AST | list[ast.stmt] | None: parsed_node: ast.AST = self.generic_visit(node) - if isinstance(parsed_node, ast.If) and isinstance( - parsed_node.test, ast.Constant - ): - if_body: list[ast.stmt] = ( - parsed_node.body if parsed_node.test.value else parsed_node.orelse - ) - return if_body or None + if isinstance(parsed_node, ast.If): + if isinstance(parsed_node.test, ast.Constant): + if_body: list[ast.stmt] = ( + parsed_node.body if parsed_node.test.value else parsed_node.orelse + ) + return if_body or None + + if not parsed_node.orelse and self._body_is_only_pass(parsed_node.body): + call_finder = _DanglingExprCallFinder( + self.optimizations_config.functions_safe_to_exclude_in_test_expr + ) + call_finder.visit(parsed_node.test) + return [ast.Expr(expr) for expr in call_finder.calls] return parsed_node @@ -832,3 +838,21 @@ def visit_Continue(self, node: ast.Continue) -> ast.Continue: def visit_Constant(self, node: ast.Constant) -> ast.Constant: return node + + +class _DanglingExprCallFinder(ast.NodeTransformer): + """Finds all calls in a given dangling expression + except for a subset of builtin functions that have + no side effects.""" + + __slots__ = ("calls", "excludes") + + def __init__(self, excludes: set[str]) -> None: + self.calls: list[ast.Call] = [] + self.excludes: set[str] = excludes + + def visit_Call(self, node: ast.Call) -> ast.Call: + if get_node_name(node) not in default_functions_safe_to_exclude_in_test_expr: + self.calls.append(node) + + return node diff --git a/personal_python_ast_optimizer/parser/utils.py b/personal_python_ast_optimizer/parser/utils.py index 988fd25..a8ee65a 100644 --- a/personal_python_ast_optimizer/parser/utils.py +++ b/personal_python_ast_optimizer/parser/utils.py @@ -37,19 +37,6 @@ def is_return_none(node: ast.Return) -> bool: return isinstance(node.value, ast.Constant) and node.value.value is None -def skip_dangling_expressions( - node: ast.Module | ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef, -) -> None: - """Removes constant dangling expression like doc strings""" - node.body = [ - element - for element in node.body - if not ( - isinstance(element, ast.Expr) and isinstance(element.value, ast.Constant) - ) - ] - - def skip_base_classes( node: ast.ClassDef, classes_to_ignore: Iterable[str] | TokensToSkip ) -> None: diff --git a/personal_python_ast_optimizer/python_info.py b/personal_python_ast_optimizer/python_info.py index c8ade66..d526632 100644 --- a/personal_python_ast_optimizer/python_info.py +++ b/personal_python_ast_optimizer/python_info.py @@ -1,5 +1,17 @@ """Various tokens in Python that the ast module writes""" +# Functions that have no side effects and thus are safe to remove +# if a test expression is found to be useless. For example: +# if "str(a) == 'a':pass" will be turned into just "str(a) == 'a'" +# but if its known str has no side effects then it can be fully removed +default_functions_safe_to_exclude_in_test_expr: set[str] = { + "int", + "str", + "isinstance", + "getattr", + "hasattr", +} + comparison_and_conjunctions: list[str] = [ " if ", " else ", diff --git a/tests/parser/test_if.py b/tests/parser/test_if.py index 0103df3..6fb2af5 100644 --- a/tests/parser/test_if.py +++ b/tests/parser/test_if.py @@ -5,17 +5,17 @@ _if_cases = [ BeforeAndAfter( """ -if a() == b:pass +if a() == b:eggs() else:pass """, - "if a()==b:pass", + "if a()==b:eggs()", ), BeforeAndAfter( """ -if a == b:pass +if a == b:eggs() else:print()""", """ -if a==b:pass +if a==b:eggs() else:print() """.strip(), ), @@ -56,6 +56,29 @@ else:bar()""", "foo()", ), + BeforeAndAfter( + "if test():pass\nelse:foo()", + "if test():pass\nelse:foo()", + ), + BeforeAndAfter( + "if test():pass\nelse:pass", + "test()", + ), + BeforeAndAfter( + "if str(a) == 'a':pass", + "", + ), + BeforeAndAfter( + "if a < 3:pass", + "", + ), + BeforeAndAfter( + """ +try:foo() +except:raise OSError +if test():pass""", + "try:foo()\nexcept:raise OSError\ntest()", + ), ] diff --git a/tests/parser/test_script.py b/tests/parser/test_script.py index 395826f..33bf82f 100644 --- a/tests/parser/test_script.py +++ b/tests/parser/test_script.py @@ -17,14 +17,14 @@ def test_one_line_if(): """ 'a' if 'True' == b else 'b' 'a' if b == 'True' else 'b' -'a' if 1==1 else 'b' -'a' if 1==2 else 'b' +a='a' if 1==1 else 'b' +b='a' if 1==2 else 'b' """, """ 'a'if'True'==b else'b' 'a'if b=='True'else'b' -'a' -'b' +a='a' +b='b' """.strip(), ) run_minifier_and_assert_correct(before_and_after) diff --git a/version.txt b/version.txt index dfda3e0..f3b5af3 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -6.1.0 +6.1.1