Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,9 @@ class InjectorMixin:
def visit(self, node: ast.AST) -> ast.AST: # noqa: D102
...

def lookup_full_name(self, node: ast.AST) -> str | None: # noqa: D102
...

def visit_FunctionDef(self, node: FunctionDef) -> None:
"""Remove and map function arguments and returns."""
super().visit_FunctionDef(node) # type: ignore[misc]
Expand All @@ -487,6 +490,12 @@ def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
if self.injector_enabled:
self.handle_injector_declaration(node)

def _has_injected_annotation(self, node: AsyncFunctionDef | FunctionDef) -> bool:
return any(
isinstance(expr, ast.Subscript) and self.lookup_full_name(expr.value) == 'injector.Inject'
for expr in iter_function_annotation_nodes(node)
)

def handle_injector_declaration(self, node: AsyncFunctionDef | FunctionDef) -> None:
"""
Adjust for injector declaration setting.
Expand All @@ -496,17 +505,11 @@ def handle_injector_declaration(self, node: AsyncFunctionDef | FunctionDef) -> N

To achieve this, we just visit the annotations to register them as "uses".
"""
for path in [node.args.args, node.args.kwonlyargs]:
for argument in path:
if hasattr(argument, 'annotation') and argument.annotation:
annotation = argument.annotation
if not hasattr(annotation, 'value'):
continue
value = annotation.value
if hasattr(value, 'id') and value.id == 'Inject':
self.visit(argument.annotation)
if hasattr(value, 'attr') and value.attr == 'Inject':
self.visit(argument.annotation)
if not self._has_injected_annotation(node):
return

for expr in iter_function_annotation_nodes(node):
self.visit(expr)


class FastAPIMixin:
Expand Down Expand Up @@ -592,6 +595,8 @@ def visit_FunctionDef(self, node: FunctionDef) -> None:
def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
"""Remove and map function arguments and returns."""
super().visit_AsyncFunctionDef(node) # type: ignore[misc]
if self.in_type_checking_block(node.lineno, node.col_offset):
return
if self.has_singledispatch_decorator(node):
for expr in iter_function_annotation_nodes(node):
self.visit(expr)
Expand Down
38 changes: 10 additions & 28 deletions tests/test_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, service: Inject[Service]) -> None:
@pytest.mark.parametrize(
('enabled', 'expected'),
[
(True, {'4:0 ' + TC002.format(module='other_dependency.OtherDependency')}),
(True, set()),
(
False,
{
Expand All @@ -65,8 +65,8 @@ def __init__(self, service: Inject[Service]) -> None:
),
],
)
def test_injector_option_only_allows_injected_dependencies(enabled, expected):
"""Whenever an injector option is enabled, only injected dependencies should be ignored."""
def test_injector_option_all_annotations_in_function_are_runtime_dependencies(enabled, expected):
"""Whenever an argument is injected, all the other annotations are runtime required too."""
example = textwrap.dedent(
'''
from injector import Inject
Expand All @@ -82,38 +82,20 @@ def __init__(self, service: Inject[Service], other: OtherDependency) -> None:
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected


@pytest.mark.parametrize(
('enabled', 'expected'),
[
(True, {'4:0 ' + TC002.format(module='other_dependency.OtherDependency')}),
(
False,
{
'2:0 ' + TC002.format(module='injector.Inject'),
'3:0 ' + TC002.format(module='services.Service'),
'4:0 ' + TC002.format(module='other_dependency.OtherDependency'),
},
),
],
)
def test_injector_option_only_allows_injector_slices(enabled, expected):
"""
Whenever an injector option is enabled, only injected dependencies should be ignored,
not any dependencies with slices.
"""
def test_injector_option_require_injections_under_unpack():
"""Whenever an injector option is enabled, injected dependencies should be ignored, even if unpacked."""
example = textwrap.dedent(
"""
from typing import Unpack
from injector import Inject
from services import Service
from other_dependency import OtherDependency

from services import ServiceKwargs
class X:
def __init__(self, service: Inject[Service], other_deps: list[OtherDependency]) -> None:
def __init__(self, service: Inject[Service], **kwargs: Unpack[ServiceKwargs]) -> None:
self.service = service
self.other_deps = other_deps
self.args = args
"""
)
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=True) == set()


@pytest.mark.parametrize(
Expand Down
Loading