Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 105 additions & 48 deletions src/fastcs/controllers/base_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from fastcs.attributes import AnyAttributeIO, Attribute, AttrR, AttrW, HintedAttribute
from fastcs.logging import bind_logger
from fastcs.methods import Command, Scan, UnboundCommand, UnboundScan
from fastcs.methods import Command, Method, Scan, UnboundCommand, UnboundScan
from fastcs.tracer import Tracer

logger = bind_logger(logger_name=__name__)
Expand Down Expand Up @@ -52,6 +52,7 @@ def __init__(

self.__hinted_attributes: dict[str, HintedAttribute] = {}
self.__hinted_sub_controllers: dict[str, type[BaseController]] = {}
self.__hinted_methods: dict[str, type[Method]] = {}
self._find_type_hints()

self._bind_attrs()
Expand Down Expand Up @@ -87,6 +88,9 @@ def _find_type_hints(self):
elif isinstance(hint, type) and issubclass(hint, BaseController):
self.__hinted_sub_controllers[name] = hint

elif isinstance(hint, type) and issubclass(hint, Method):
self.__hinted_methods[name] = hint

def _bind_attrs(self) -> None:
"""Search for Attributes and Methods to bind them to this instance.

Expand Down Expand Up @@ -168,47 +172,77 @@ def post_initialise(self):
self._connect_attribute_ios()

def _validate_type_hints(self):
"""Validate all `Attribute` and `Controller` type-hints were introspected"""
"""Validate all `Attribute`, `Controller`, and `Method`
type-hints were introspected"""
for name in self.__hinted_attributes:
self._validate_hinted_attribute(name)

for name in self.__hinted_sub_controllers:
self._validate_hinted_controller(name)

for name in self.__hinted_methods:
self._validate_hinted_method(name)

for subcontroller in self.sub_controllers.values():
subcontroller._validate_type_hints() # noqa: SLF001

def _validate_hinted_member(self, name: str, expected_type: type):
"""Validate that a hinted member exists on the controller"""
member = getattr(self, name, None)
if member is None or not isinstance(member, expected_type):
raise RuntimeError()
return member

def _validate_hinted_method(self, name: str):
"""Check that a `Method` with the given name exists on the controller"""
try:
method = self._validate_hinted_member(name, Method)
except RuntimeError as exc:
raise RuntimeError(
f"Controller `{self.__class__.__name__}` failed to introspect "
f"hinted method `{name}` during initialisation"
) from exc

logger.debug(
"Validated hinted method",
name=name,
controller=self,
method=method,
)

def _validate_hinted_attribute(self, name: str):
"""Check that an `Attribute` with the given name exists on the controller"""
attr = getattr(self, name, None)
if attr is None or not isinstance(attr, Attribute):
try:
attr = self._validate_hinted_member(name, Attribute)
except RuntimeError as exc:
raise RuntimeError(
f"Controller `{self.__class__.__name__}` failed to introspect "
f"hinted attribute `{name}` during initialisation"
)
else:
logger.debug(
"Validated hinted attribute",
name=name,
controller=self,
attribute=attr,
)
) from exc

logger.debug(
"Validated hinted attribute",
name=name,
controller=self,
attribute=attr,
)

def _validate_hinted_controller(self, name: str):
"""Check that a sub controller with the given name exists on the controller"""
controller = getattr(self, name, None)
if controller is None or not isinstance(controller, BaseController):
try:
controller = self._validate_hinted_member(name, BaseController)
except RuntimeError as exc:
raise RuntimeError(
f"Controller `{self.__class__.__name__}` failed to introspect "
f"hinted controller `{name}` during initialisation"
)
else:
logger.debug(
"Validated hinted sub controller",
name=name,
controller=self,
sub_controller=controller,
)
) from exc

logger.debug(
"Validated hinted sub controller",
name=name,
controller=self,
sub_controller=controller,
)

def _connect_attribute_ios(self) -> None:
"""Connect ``Attribute`` callbacks to ``AttributeIO``s"""
Expand Down Expand Up @@ -245,14 +279,27 @@ def set_path(self, path: list[str]):
for attribute in self.__attributes.values():
attribute.set_path(path)

def _check_for_name_clash(self, name: str):
namespaces = {
"attribute": self.__attributes,
"sub controller": self.__sub_controllers,
"scan method": self.__scan_methods,
"command method": self.__command_methods,
}

for kind, namespace in namespaces.items():
if name in namespace:
raise ValueError(
f"Controller {self} has existing {kind} {name}: {namespace[name]}"
)

def add_attribute(self, name, attr: Attribute):
if name in self.__attributes:
raise ValueError(
f"Cannot add attribute {attr}. "
f"Controller {self} has has existing attribute {name}: "
f"{self.__attributes[name]}"
)
elif name in self.__hinted_attributes:
try:
self._check_for_name_clash(name)
except ValueError as exc:
raise ValueError(f"Cannot add attribute {attr}.") from exc

if name in self.__hinted_attributes:
hint = self.__hinted_attributes[name]
if not isinstance(attr, hint.attr_type):
raise RuntimeError(
Expand All @@ -267,12 +314,6 @@ def add_attribute(self, name, attr: Attribute):
f"Expected '{hint.dtype.__name__}', "
f"got '{attr.datatype.dtype.__name__}'."
)
elif name in self.__sub_controllers.keys():
raise ValueError(
f"Cannot add attribute {attr}. "
f"Controller {self} has existing sub controller {name}: "
f"{self.__sub_controllers[name]}"
)

attr.set_name(name)
attr.set_path(self.path)
Expand All @@ -284,13 +325,12 @@ def attributes(self) -> dict[str, Attribute]:
return self.__attributes

def add_sub_controller(self, name: str, sub_controller: BaseController):
if name in self.__sub_controllers.keys():
raise ValueError(
f"Cannot add sub controller {sub_controller}. "
f"Controller {self} has existing sub controller {name}: "
f"{self.__sub_controllers[name]}"
)
elif name in self.__hinted_sub_controllers:
try:
self._check_for_name_clash(name)
except ValueError as exc:
raise ValueError(f"Cannot add sub controller {sub_controller}.") from exc

if name in self.__hinted_sub_controllers:
hint = self.__hinted_sub_controllers[name]
if not isinstance(sub_controller, hint):
raise RuntimeError(
Expand All @@ -299,12 +339,6 @@ def add_sub_controller(self, name: str, sub_controller: BaseController):
f"Expected '{hint.__name__}' got "
f"'{sub_controller.__class__.__name__}'."
)
elif name in self.__attributes:
raise ValueError(
f"Cannot add sub controller {sub_controller}. "
f"Controller {self} has existing attribute {name}: "
f"{self.__attributes[name]}"
)

sub_controller.set_path(self.path + [name])
self.__sub_controllers[name] = sub_controller
Expand All @@ -317,7 +351,24 @@ def add_sub_controller(self, name: str, sub_controller: BaseController):
def sub_controllers(self) -> dict[str, BaseController]:
return self.__sub_controllers

def _validated_added_method(self, name: str, method: Method):
if name in self.__hinted_methods:
hint = self.__hinted_methods[name]
if not isinstance(method, hint):
raise RuntimeError(
f"Controller '{self.__class__.__name__}' introspection of "
f"hinted method '{name}' does not match defined type. "
f"Expected '{hint.__name__}' got "
f"'{method.__class__.__name__}'."
)

def add_command(self, name: str, command: Command):
try:
self._check_for_name_clash(name)
self._validated_added_method(name, command)
except (ValueError, RuntimeError) as exc:
raise exc.__class__(f"Cannot add command method {command}.") from exc

self.__command_methods[name] = command
super().__setattr__(name, command)

Expand All @@ -326,6 +377,12 @@ def command_methods(self) -> dict[str, Command]:
return self.__command_methods

def add_scan(self, name: str, scan: Scan):
try:
self._check_for_name_clash(name)
self._validated_added_method(name, scan)
except (ValueError, RuntimeError) as exc:
raise exc.__class__(f"Cannot add scan method {scan}.") from exc

self.__scan_methods[name] = scan
super().__setattr__(name, scan)

Expand Down
1 change: 1 addition & 0 deletions src/fastcs/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .command import CommandCallback as CommandCallback
from .command import UnboundCommand as UnboundCommand
from .command import command as command
from .method import Method as Method
from .scan import Scan as Scan
from .scan import ScanCallback as ScanCallback
from .scan import UnboundScan as UnboundScan
Expand Down
63 changes: 44 additions & 19 deletions tests/test_controllers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fastcs.attributes import AttrR, AttrRW
from fastcs.controllers import Controller, ControllerVector
from fastcs.datatypes import Enum, Float, Int
from fastcs.methods import Command, Scan


def test_controller_nesting():
Expand All @@ -20,7 +21,7 @@ def test_controller_nesting():
assert controller.sub_controllers == {"a": sub_controller}
assert sub_controller.sub_controllers == {"b": sub_sub_controller}

with pytest.raises(ValueError, match=r"existing sub controller"):
with pytest.raises(ValueError, match=r"Cannot add sub controller"):
controller.a = Controller()

with pytest.raises(ValueError, match=r"already registered"):
Expand Down Expand Up @@ -76,33 +77,39 @@ def test_attribute_parsing():
}


def test_conflicting_attributes_and_controllers():
async def noop() -> None:
pass


@pytest.mark.parametrize(
"member_name, member_value, expected_error",
[
("attr", AttrR(Float()), r"Cannot add attribute"),
("attr", Controller(), r"Cannot add sub controller"),
("attr", Command(noop), r"Cannot add command"),
("sub_controller", AttrR(Int()), r"Cannot add attribute"),
("sub_controller", Controller(), r"Cannot add sub controller"),
("sub_controller", Command(noop), r"Cannot add command"),
("cmd", AttrR(Int()), r"Cannot add attribute"),
("cmd", Controller(), r"Cannot add sub controller"),
("cmd", Command(noop), r"Cannot add command"),
],
)
def test_conflicting_attributes_and_controllers_and_commands(
member_name, member_value, expected_error
):
class ConflictingController(Controller):
attr = AttrR(Int())
cmd = Command(noop)

def __init__(self):
super().__init__()
self.sub_controller = Controller()

controller = ConflictingController()

with pytest.raises(ValueError, match=r"Cannot add attribute .* existing attribute"):
controller.attr = AttrR(Float()) # pyright: ignore[reportAttributeAccessIssue]

with pytest.raises(
ValueError, match=r"Cannot add sub controller .* existing attribute"
):
controller.attr = Controller() # pyright: ignore[reportAttributeAccessIssue]

with pytest.raises(
ValueError, match=r"Cannot add sub controller .* existing sub controller"
):
controller.sub_controller = Controller()

with pytest.raises(
ValueError, match=r"Cannot add attribute .* existing sub controller"
):
controller.sub_controller = AttrR(Int()) # pyright: ignore[reportAttributeAccessIssue]
with pytest.raises(ValueError, match=expected_error):
setattr(controller, member_name, member_value)


def test_controller_raises_error_if_passed_numeric_sub_controller_name():
Expand Down Expand Up @@ -203,3 +210,21 @@ class HintedController(Controller):

controller.add_sub_controller("child", SomeSubController())
controller._validate_type_hints()


@pytest.mark.asyncio
async def test_method_hint_validation():
class HintedController(Controller):
method: Scan

controller = HintedController()

with pytest.raises(RuntimeError, match="failed to introspect hinted method"):
controller._validate_type_hints()

with pytest.raises(RuntimeError, match="Cannot add command method"):
controller.add_command("method", Command(noop))

controller.add_scan("method", Scan(fn=noop, period=0.1))

controller._validate_type_hints()