diff --git a/personal_python_ast_optimizer/parser/skipper.py b/personal_python_ast_optimizer/parser/skipper.py index 9394b36..86136d0 100644 --- a/personal_python_ast_optimizer/parser/skipper.py +++ b/personal_python_ast_optimizer/parser/skipper.py @@ -195,25 +195,16 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST | None: skip_base_classes(node, self.tokens_config.classes_to_skip) skip_decorators(node, self.tokens_config.decorators_to_skip) - parsed_node = self.generic_visit(node) - if ( self.token_types_config.simplify_named_tuples - and isinstance(parsed_node, ast.ClassDef) - and self._is_simple_named_tuple(parsed_node) + and isinstance(node, ast.ClassDef) + and self._is_simple_named_tuple(node) ): self._simplified_named_tuple = True - named_tuple = ast.Call( - ast.Name("namedtuple"), - [ - ast.Constant(parsed_node.name), - ast.List([ast.Constant(n.target.id) for n in parsed_node.body]), # type: ignore - ], - [], - ) - return ast.Assign([ast.Name(parsed_node.name)], named_tuple) + named_tuple = self._build_named_tuple(node) + return ast.Assign([ast.Name(node.name)], named_tuple) - return parsed_node + return self.generic_visit(node) @staticmethod def _is_simple_named_tuple(node: ast.ClassDef) -> bool: @@ -224,13 +215,47 @@ def _is_simple_named_tuple(node: ast.ClassDef) -> bool: and not node.keywords and not node.decorator_list and all( - isinstance(n, ast.AnnAssign) - and isinstance(n.target, ast.Name) - and n.value is None + isinstance(n, ast.AnnAssign) and isinstance(n.target, ast.Name) for n in node.body ) ) + @staticmethod + def _build_named_tuple(node: ast.ClassDef) -> ast.Call: + """Build what a namedtuple node would be for a given + class def inheriting from NamedTuple with only AnnAssigns in the body.""" + + defaults: list[ast.expr] + + if node.body: + defaults = [node.body[0].value] if node.body[0].value is not None else [] # type: ignore + + for i in range(1, len(node.body)): + assign: ast.AnnAssign = node.body[i] # type: ignore + if assign.value is not None: + defaults.append(assign.value) + elif node.body[i - 1].value is not None: # type: ignore + raise ValueError( + f"Non-default namedtuple {node.name} field " + "cannot follow default field" + ) + + else: + defaults = [] + + keywords: list[ast.keyword] = ( + [ast.keyword("defaults", ast.List(defaults))] if defaults else [] + ) + + return ast.Call( + ast.Name("namedtuple"), + [ + ast.Constant(node.name), + ast.List([ast.Constant(n.target.id) for n in node.body]), # type: ignore + ], + keywords, + ) + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST | None: return self._handle_function_node(node) diff --git a/requirements_dev.txt b/requirements_dev.txt index 365a146..1141147 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -2,4 +2,4 @@ codespell==2.4.1 mypy==1.19.1 pytest==9.0.2 pytest-cov==7.0.0 -ruff==0.14.11 +ruff==0.14.12 diff --git a/tests/parser/test_tuple.py b/tests/parser/test_tuple.py index 0a3171f..3eaa39e 100644 --- a/tests/parser/test_tuple.py +++ b/tests/parser/test_tuple.py @@ -50,6 +50,17 @@ class A(NamedTuple): """, "from collections import OrderedDict,namedtuple\nA=namedtuple('A',['foo','bar'])\nb=OrderedDict()", # noqa: E501 ), + ( + """ +from typing import NamedTuple + +class A(NamedTuple): + foo: int + bar: int = 2 + spam: str = 'a' +""", + "from collections import namedtuple\nA=namedtuple('A',['foo','bar','spam'],defaults=[2,'a'])", # noqa: E501 + ), ] @@ -61,3 +72,23 @@ def test_simplify_named_tuple(before: str, after: str): before_and_after, token_types_config=TokenTypesConfig(simplify_named_tuples=True), ) + + +def test_simplify_named_tuple_error(): + before_and_after = BeforeAndAfter( + """ +from typing import NamedTuple + +class A(NamedTuple): + foo: int = 2 + bar: int + spam: str = 'a' +""", + "", + ) + + with pytest.raises(ValueError): + run_minifier_and_assert_correct( + before_and_after, + token_types_config=TokenTypesConfig(simplify_named_tuples=True), + ) diff --git a/version.txt b/version.txt index 5fe6072..9b9a244 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -6.0.1 +6.0.2