From 2ccbb5eaef4c47e03b423f833b8715bc4b16f26c Mon Sep 17 00:00:00 2001 From: rozyczko Date: Mon, 1 Dec 2025 13:43:06 +0100 Subject: [PATCH 01/14] update __init__ to expose new classes --- .../fitting/calculators/__init__.py | 10 +- .../fitting/calculators/calculator_base.py | 261 +++++++ .../fitting/calculators/calculator_factory.py | 336 +++++++++ .../calculators/test_calculator_base.py | 283 ++++++++ .../calculators/test_calculator_factory.py | 644 ++++++++++++++++++ 5 files changed, 1533 insertions(+), 1 deletion(-) create mode 100644 src/easyscience/fitting/calculators/calculator_base.py create mode 100644 src/easyscience/fitting/calculators/calculator_factory.py create mode 100644 tests/unit_tests/fitting/calculators/test_calculator_base.py create mode 100644 tests/unit_tests/fitting/calculators/test_calculator_factory.py diff --git a/src/easyscience/fitting/calculators/__init__.py b/src/easyscience/fitting/calculators/__init__.py index a3ca5d43..e0c5d480 100644 --- a/src/easyscience/fitting/calculators/__init__.py +++ b/src/easyscience/fitting/calculators/__init__.py @@ -2,6 +2,14 @@ # SPDX-License-Identifier: BSD-3-Clause # © 2021-2025 Contributors to the EasyScience project +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project None: + """ + Initialize the calculator with a model and instrumental parameters. + + Parameters + ---------- + model : NewBase + The physical model to calculate from. This is typically a sample + or structure definition containing fittable parameters. + instrumental_parameters : NewBase, optional + Instrumental parameters that affect the calculation, such as + resolution, wavelength, or detector settings. + unique_name : str, optional + Unique identifier for this calculator instance. + display_name : str, optional + Human-readable name for display purposes. + **kwargs : Any + Additional calculator-specific options. + """ + if model is None: + raise ValueError("Model cannot be None") + + # Initialize NewBase with naming + super().__init__(unique_name=unique_name, display_name=display_name) + + self._model = model + self._instrumental_parameters = instrumental_parameters + self._additional_kwargs = kwargs + + @property + def model(self) -> NewBase: + """ + Get the current physical model. + + Returns + ------- + NewBase + The physical model used for calculations. + """ + return self._model + + @model.setter + def model(self, new_model: NewBase) -> None: + """ + Set a new physical model. + + Parameters + ---------- + new_model : NewBase + The new physical model to use for calculations. + + Raises + ------ + ValueError + If the new model is None. + """ + if new_model is None: + raise ValueError("Model cannot be None") + self._model = new_model + + @property + def instrumental_parameters(self) -> Optional[NewBase]: + """ + Get the current instrumental parameters. + + Returns + ------- + NewBase or None + The instrumental parameters, or None if not set. + """ + return self._instrumental_parameters + + @instrumental_parameters.setter + def instrumental_parameters(self, new_parameters: Optional[NewBase]) -> None: + """ + Set new instrumental parameters. + + Parameters + ---------- + new_parameters : NewBase or None + The new instrumental parameters to use for calculations. + Truly optional, since instrumental parameters may not always be needed. + """ + self._instrumental_parameters = new_parameters + + def update_model(self, new_model: NewBase) -> None: + """ + Update the physical model used for calculations. + + This is an alternative to the `model` property setter that can be + overridden by subclasses to perform additional setup when the model changes. + + Parameters + ---------- + new_model : NewBase + The new physical model to use. + + Raises + ------ + ValueError + If the new model is None. + """ + self.model = new_model + + def update_instrumental_parameters(self, new_parameters: Optional[NewBase]) -> None: + """ + Update the instrumental parameters used for calculations. + + This is an alternative to the `instrumental_parameters` property setter + that can be overridden by subclasses to perform additional setup when + instrumental parameters change. + + Parameters + ---------- + new_parameters : NewBase or None + The new instrumental parameters to use. + """ + self.instrumental_parameters = new_parameters + + @property + def additional_kwargs(self) -> dict: + """ + Get additional keyword arguments passed during initialization. + + Returns + ------- + dict + Dictionary of additional kwargs passed to __init__. + """ + return self._additional_kwargs + + @abstractmethod + def calculate(self, x: np.ndarray) -> np.ndarray: + """ + Calculate theoretical values at the given points. + + This is the main calculation method that must be implemented by all + concrete calculator classes. It uses the current model and instrumental + parameters to compute theoretical predictions. + + Parameters + ---------- + x : np.ndarray + The independent variable values (e.g., Q values, angles, energies) + at which to calculate the theoretical response. + + Returns + ------- + np.ndarray + The calculated theoretical values corresponding to the input x values. + + Notes + ----- + This method is called during fitting and should be thread-safe if + parallel fitting is to be supported. + """ + ... + + def __repr__(self) -> str: + """Return a string representation of the calculator.""" + model_name = getattr(self._model, 'name', type(self._model).__name__) + instr_info = "" + if self._instrumental_parameters is not None: + instr_name = getattr( + self._instrumental_parameters, + 'name', + type(self._instrumental_parameters).__name__ # default to class name if no 'name' attribute + ) + instr_info = f", instrumental_parameters={instr_name}" + return f"{self.__class__.__name__}(model={model_name}{instr_info})" diff --git a/src/easyscience/fitting/calculators/calculator_factory.py b/src/easyscience/fitting/calculators/calculator_factory.py new file mode 100644 index 00000000..50a55e10 --- /dev/null +++ b/src/easyscience/fitting/calculators/calculator_factory.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project List[str]: + """ + Return a list of available calculator names. + + Returns + ------- + List[str] + Names of all calculators that can be created by this factory. + """ + ... + + @abstractmethod + def create( + self, + calculator_name: str, + model: NewBase, + instrumental_parameters: Optional[NewBase] = None, + **kwargs: Any, + ) -> CalculatorBase: + """ + Create a calculator instance. + + Parameters + ---------- + calculator_name : str + The name of the calculator to create. Must be one of the names + returned by `available_calculators`. + model : NewBase + The physical model (e.g., sample) to pass to the calculator. + instrumental_parameters : NewBase, optional + Instrumental parameters to pass to the calculator. + **kwargs : Any + Additional arguments to pass to the calculator constructor. + + Returns + ------- + CalculatorBase + A new calculator instance configured with the given model and + instrumental parameters. + + Raises + ------ + ValueError + If the requested calculator_name is not available. + """ + ... + + def __repr__(self) -> str: + """Return a string representation of the factory.""" + return f"{self.__class__.__name__}(available={self.available_calculators})" + + +class SimpleCalculatorFactory(CalculatorFactoryBase): + """ + A simple implementation of a calculator factory using a dictionary registry. + + This class provides a convenient base for creating calculator factories + where calculators are registered in a dictionary. Subclasses only need + to populate the `_calculators` class attribute. + + Parameters + ---------- + calculators : Dict[str, Type[CalculatorBase]], optional + A dictionary mapping calculator names to calculator classes. + If not provided, uses the class-level `_calculators` attribute. + + Attributes + ---------- + _calculators : Dict[str, Type[CalculatorBase]] + Class-level dictionary of registered calculators. Subclasses should + override this with their available calculators. + + Examples + -------- + Using class-level registration:: + + class MyFactory(SimpleCalculatorFactory): + _calculators = { + 'fast': FastCalculator, + 'accurate': AccurateCalculator, + } + + factory = MyFactory() + calc = factory.create('fast', model, instrument) + + Using instance-level registration:: + + factory = SimpleCalculatorFactory({ + 'custom': CustomCalculator, + }) + calc = factory.create('custom', model, instrument) + """ + + _calculators: Dict[str, Type[CalculatorBase]] = {} + + def __init__( + self, + calculators: Optional[Dict[str, Type[CalculatorBase]]] = None, + ) -> None: + """ + Initialize the factory with optional calculator registry. + + Parameters + ---------- + calculators : Dict[str, Type[CalculatorBase]], optional + A dictionary mapping calculator names to calculator classes. + If provided, overrides the class-level `_calculators` attribute. + """ + # Create instance-level copy to prevent bleeding between instances + if calculators is not None: + self._calculators = dict(calculators) + else: + # Create a copy of the class-level registry for this instance + self._calculators = dict(self.__class__._calculators) + + @property + def available_calculators(self) -> List[str]: + """ + Return a list of available calculator names. + + Returns + ------- + List[str] + Names of all registered calculators. + """ + return list(self._calculators.keys()) + + def create( + self, + calculator_name: str, + model: NewBase, + instrumental_parameters: Optional[NewBase] = None, + **kwargs: Any, + ) -> CalculatorBase: + """ + Create a calculator instance from the registered calculators. + + Parameters + ---------- + calculator_name : str + The name of the calculator to create. + model : NewBase + The physical model to pass to the calculator. + instrumental_parameters : NewBase, optional + Instrumental parameters to pass to the calculator. + **kwargs : Any + Additional arguments to pass to the calculator constructor. + + Returns + ------- + CalculatorBase + A new calculator instance. + + Raises + ------ + ValueError + If the calculator_name is not in the registry or is not a string. + TypeError + If model is None or instrumental_parameters has wrong type. + """ + if not isinstance(calculator_name, str): + raise ValueError(f"calculator_name must be a string, got {type(calculator_name).__name__}") + + if calculator_name not in self._calculators: + available = ", ".join(self.available_calculators) if self.available_calculators else "none" + raise ValueError( + f"Unknown calculator '{calculator_name}'. " + f"Available calculators: {available}" + ) + + if model is None: + raise TypeError("model cannot be None") + + calculator_class = self._calculators[calculator_name] + try: + return calculator_class(model, instrumental_parameters, **kwargs) + except Exception as e: + raise type(e)( + f"Failed to create calculator '{calculator_name}': {e}" + ) from e + + def register(self, name: str, calculator_class: Type[CalculatorBase]) -> None: + """ + Register a new calculator class with the factory. + + Parameters + ---------- + name : str + The name to register the calculator under. + calculator_class : Type[CalculatorBase] + The calculator class to register. + + Raises + ------ + TypeError + If calculator_class is not a subclass of CalculatorBase. + ValueError + If name is empty or not a string. + + Warnings + -------- + If overwriting an existing calculator, a warning is issued. + """ + # Import here to avoid circular imports at module level + import warnings + + from .calculator_base import CalculatorBase + + if not isinstance(name, str) or not name: + raise ValueError("Calculator name must be a non-empty string") + + if not (isinstance(calculator_class, type) and issubclass(calculator_class, CalculatorBase)): + raise TypeError( + f"calculator_class must be a subclass of CalculatorBase, " + f"got {type(calculator_class).__name__}" + ) + + if name in self._calculators: + warnings.warn( + f"Overwriting existing calculator '{name}' in {self.__class__.__name__}", + UserWarning, + stacklevel=2 + ) + + self._calculators[name] = calculator_class + + def unregister(self, name: str) -> None: + """ + Remove a calculator from the registry. + + Parameters + ---------- + name : str + The name of the calculator to remove. + + Raises + ------ + KeyError + If the calculator name is not in the registry. + """ + if name not in self._calculators: + raise KeyError(f"Calculator '{name}' is not registered") + del self._calculators[name] diff --git a/tests/unit_tests/fitting/calculators/test_calculator_base.py b/tests/unit_tests/fitting/calculators/test_calculator_base.py new file mode 100644 index 00000000..38b3c601 --- /dev/null +++ b/tests/unit_tests/fitting/calculators/test_calculator_base.py @@ -0,0 +1,283 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project np.ndarray: + # Simple identity function for testing + return x * 2.0 + + return ConcreteCalculator + + @pytest.fixture + def calculator(self, clear, concrete_calculator_class, mock_model, mock_instrumental_parameters): + """Create a calculator instance for testing.""" + return concrete_calculator_class( + mock_model, mock_instrumental_parameters, unique_name="test_calc", display_name="TestCalc" + ) + + # Initialization tests + def test_init_with_model_only(self, clear, concrete_calculator_class, mock_model): + """Test initialization with only a model.""" + calc = concrete_calculator_class(mock_model, unique_name="test_1", display_name="Test1") + assert calc.model is mock_model + assert calc.instrumental_parameters is None + + def test_init_with_model_and_instrumental_parameters( + self, clear, concrete_calculator_class, mock_model, mock_instrumental_parameters + ): + """Test initialization with model and instrumental parameters.""" + calc = concrete_calculator_class( + mock_model, mock_instrumental_parameters, unique_name="test_2", display_name="Test2" + ) + assert calc.model is mock_model + assert calc.instrumental_parameters is mock_instrumental_parameters + + def test_init_with_kwargs(self, clear, concrete_calculator_class, mock_model): + """Test initialization with additional kwargs.""" + calc = concrete_calculator_class( + mock_model, unique_name="test_3", display_name="Test3", custom_option="value" + ) + assert calc.additional_kwargs == {"custom_option": "value"} + + def test_init_with_none_model_raises_error(self, clear, concrete_calculator_class): + """Test that initialization with None model raises ValueError.""" + with pytest.raises(ValueError, match="Model cannot be None"): + concrete_calculator_class(None, unique_name="test_4", display_name="Test4") + + # Model property tests + def test_model_getter(self, calculator, mock_model): + """Test model getter property.""" + assert calculator.model is mock_model + + def test_model_setter(self, calculator): + """Test model setter property.""" + new_model = MagicMock() + new_model.name = "NewModel" + calculator.model = new_model + assert calculator.model is new_model + + def test_model_setter_with_none_raises_error(self, calculator): + """Test that setting model to None raises ValueError.""" + with pytest.raises(ValueError, match="Model cannot be None"): + calculator.model = None + + # Instrumental parameters property tests + def test_instrumental_parameters_getter(self, calculator, mock_instrumental_parameters): + """Test instrumental_parameters getter property.""" + assert calculator.instrumental_parameters is mock_instrumental_parameters + + def test_instrumental_parameters_setter(self, calculator): + """Test instrumental_parameters setter property.""" + new_params = MagicMock() + new_params.name = "NewInstrument" + calculator.instrumental_parameters = new_params + assert calculator.instrumental_parameters is new_params + + def test_instrumental_parameters_setter_with_none(self, calculator): + """Test that instrumental_parameters can be set to None.""" + calculator.instrumental_parameters = None + assert calculator.instrumental_parameters is None + + # Update methods tests + def test_update_model(self, calculator): + """Test update_model method.""" + new_model = MagicMock() + new_model.name = "UpdatedModel" + calculator.update_model(new_model) + assert calculator.model is new_model + + def test_update_model_with_none_raises_error(self, calculator): + """Test that update_model with None raises ValueError.""" + with pytest.raises(ValueError, match="Model cannot be None"): + calculator.update_model(None) + + def test_update_instrumental_parameters(self, calculator): + """Test update_instrumental_parameters method.""" + new_params = MagicMock() + new_params.name = "UpdatedInstrument" + calculator.update_instrumental_parameters(new_params) + assert calculator.instrumental_parameters is new_params + + def test_update_instrumental_parameters_with_none(self, calculator): + """Test that update_instrumental_parameters accepts None.""" + calculator.update_instrumental_parameters(None) + assert calculator.instrumental_parameters is None + + # Calculate method tests + def test_calculate_returns_array(self, calculator): + """Test that calculate returns an array.""" + x = np.array([1.0, 2.0, 3.0]) + result = calculator.calculate(x) + assert isinstance(result, np.ndarray) + np.testing.assert_array_equal(result, np.array([2.0, 4.0, 6.0])) + + def test_calculate_with_empty_array(self, calculator): + """Test calculate with empty array.""" + x = np.array([]) + result = calculator.calculate(x) + assert len(result) == 0 + + # Abstract method enforcement tests + def test_cannot_instantiate_abstract_class(self, mock_model): + """Test that CalculatorBase cannot be instantiated directly.""" + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + CalculatorBase(mock_model) + + def test_subclass_must_implement_calculate(self, mock_model): + """Test that subclasses must implement calculate method.""" + + class IncompleteCalculator(CalculatorBase): + pass # Does not implement calculate + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + IncompleteCalculator(mock_model) + + # Representation tests + def test_repr_with_model_only(self, clear, concrete_calculator_class, mock_model): + """Test __repr__ with only model.""" + calc = concrete_calculator_class(mock_model, unique_name="test_5", display_name="Test5") + repr_str = repr(calc) + assert "ConcreteCalculator" in repr_str + assert "model=MockModel" in repr_str + assert "instrumental_parameters" not in repr_str + + def test_repr_with_model_and_instrumental_parameters( + self, clear, concrete_calculator_class, mock_model, mock_instrumental_parameters + ): + """Test __repr__ with model and instrumental parameters.""" + calc = concrete_calculator_class(mock_model, mock_instrumental_parameters, unique_name="test_6", display_name="Test6") + repr_str = repr(calc) + assert "ConcreteCalculator" in repr_str + assert "model=MockModel" in repr_str + assert "instrumental_parameters=MockInstrument" in repr_str + + def test_repr_with_model_without_name_attribute(self, clear, concrete_calculator_class): + """Test __repr__ when model has no name attribute.""" + model = MagicMock(spec=[]) # No name attribute + calc = concrete_calculator_class(model, unique_name="test_7", display_name="Test7") + repr_str = repr(calc) + assert "ConcreteCalculator" in repr_str + assert "MagicMock" in repr_str + + # Name attribute tests + def test_calculator_name_attribute(self, calculator): + """Test that calculator has name attribute.""" + assert calculator.name == "test_calculator" + + def test_default_name_is_base(self): + """Test that default name is 'base'.""" + assert CalculatorBase.name == "base" + + # Additional kwargs property tests + def test_additional_kwargs_with_init(self, clear, concrete_calculator_class, mock_model): + """Test additional_kwargs property with kwargs in init.""" + calc = concrete_calculator_class( + mock_model, + unique_name="test_8", + display_name="Test8", + custom_option="value", + numeric_param=42 + ) + assert calc.additional_kwargs == {"custom_option": "value", "numeric_param": 42} + + def test_additional_kwargs_empty_by_default(self, clear, concrete_calculator_class, mock_model): + """Test that additional_kwargs is empty dict when no kwargs provided.""" + calc = concrete_calculator_class(mock_model, unique_name="test_9", display_name="Test9") + assert calc.additional_kwargs == {} + + +class TestCalculatorBaseWithRealModel: + """Integration-style tests using actual EasyScience objects.""" + + @pytest.fixture + def clear(self): + """Clear global map to avoid test contamination.""" + global_object.map._clear() + yield + global_object.map._clear() + + @pytest.fixture + def real_parameter(self, clear): + """Create a real Parameter object.""" + from easyscience.variable import Parameter + return Parameter("test_param", value=5.0, unit="m") + + @pytest.fixture + def concrete_calculator_class(self): + """Create a concrete implementation that uses model parameters.""" + + class ParameterAwareCalculator(CalculatorBase): + name = "param_aware" + + def calculate(self, x: np.ndarray) -> np.ndarray: + # Access parameter from model if available + if hasattr(self._model, 'get_parameters'): + params = self._model.get_parameters() + if params: + scale = params[0].value + return x * scale + return x + + return ParameterAwareCalculator + + def test_calculator_can_access_model_parameters( + self, clear, concrete_calculator_class, real_parameter + ): + """Test that calculator can access parameters from model.""" + # Create a mock model that returns our real parameter + model = MagicMock() + model.name = "TestModel" + model.get_parameters.return_value = [real_parameter] + + calc = concrete_calculator_class(model, unique_name="test_10", display_name="Test10") + x = np.array([1.0, 2.0, 3.0]) + result = calc.calculate(x) + + # Should multiply by parameter value (5.0) + np.testing.assert_array_equal(result, np.array([5.0, 10.0, 15.0])) diff --git a/tests/unit_tests/fitting/calculators/test_calculator_factory.py b/tests/unit_tests/fitting/calculators/test_calculator_factory.py new file mode 100644 index 00000000..f2c7743f --- /dev/null +++ b/tests/unit_tests/fitting/calculators/test_calculator_factory.py @@ -0,0 +1,644 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project np.ndarray: + return x * 2.0 + + return TestCalculator + + @pytest.fixture + def concrete_factory_class(self, concrete_calculator_class): + """Create a concrete factory implementation.""" + + class TestFactory(CalculatorFactoryBase): + def __init__(self): + self._calc_class = concrete_calculator_class + + @property + def available_calculators(self) -> List[str]: + return ["test"] + + def create(self, calculator_name, model, instrumental_parameters=None, **kwargs): + if calculator_name != "test": + raise ValueError(f"Unknown calculator: {calculator_name}") + return self._calc_class(model, instrumental_parameters, **kwargs) + + return TestFactory + + # Abstract class enforcement tests + def test_cannot_instantiate_abstract_factory(self): + """Test that CalculatorFactoryBase cannot be instantiated directly.""" + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + CalculatorFactoryBase() + + def test_subclass_must_implement_available_calculators(self, concrete_calculator_class, mock_model): + """Test that subclasses must implement available_calculators property.""" + + class IncompleteFactory(CalculatorFactoryBase): + def create(self, calculator_name, model, instrumental_parameters=None, **kwargs): + return concrete_calculator_class(model, instrumental_parameters, **kwargs) + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + IncompleteFactory() + + def test_subclass_must_implement_create(self): + """Test that subclasses must implement create method.""" + + class IncompleteFactory(CalculatorFactoryBase): + @property + def available_calculators(self): + return ["test"] + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + IncompleteFactory() + + # Concrete factory tests + def test_factory_available_calculators(self, concrete_factory_class): + """Test available_calculators property.""" + factory = concrete_factory_class() + assert factory.available_calculators == ["test"] + + def test_factory_create_calculator(self, concrete_factory_class, mock_model, mock_instrumental_parameters): + """Test creating a calculator via factory.""" + factory = concrete_factory_class() + calculator = factory.create("test", mock_model, mock_instrumental_parameters) + assert isinstance(calculator, CalculatorBase) + assert calculator.model is mock_model + assert calculator.instrumental_parameters is mock_instrumental_parameters + + def test_factory_create_with_model_only(self, concrete_factory_class, mock_model): + """Test creating calculator with only model.""" + factory = concrete_factory_class() + calculator = factory.create("test", mock_model) + assert calculator.model is mock_model + assert calculator.instrumental_parameters is None + + def test_factory_create_unknown_calculator_raises_error(self, concrete_factory_class, mock_model): + """Test that creating unknown calculator raises ValueError.""" + factory = concrete_factory_class() + with pytest.raises(ValueError, match="Unknown calculator"): + factory.create("unknown", mock_model) + + # Repr tests + def test_factory_repr(self, concrete_factory_class): + """Test factory __repr__.""" + factory = concrete_factory_class() + repr_str = repr(factory) + assert "TestFactory" in repr_str + assert "test" in repr_str + + +class TestSimpleCalculatorFactory: + """Tests for SimpleCalculatorFactory class.""" + + @pytest.fixture + def clear(self): + """Clear global map to avoid test contamination.""" + global_object.map._clear() + yield + global_object.map._clear() + + @pytest.fixture + def mock_model(self): + """Create a mock model object.""" + model = MagicMock() + model.name = "MockModel" + model.unique_name = "MockModel" + model.display_name = "MockModel" + return model + + @pytest.fixture + def mock_instrumental_parameters(self): + """Create mock instrumental parameters.""" + params = MagicMock() + params.name = "MockInstrument" + params.unique_name = "MockInstrument" + params.display_name = "MockInstrument" + return params + + @pytest.fixture + def calculator_class_a(self): + """Create first concrete calculator class.""" + + class CalculatorA(CalculatorBase): + name = "calc_a" + + def calculate(self, x: np.ndarray) -> np.ndarray: + return x * 2.0 + + return CalculatorA + + @pytest.fixture + def calculator_class_b(self): + """Create second concrete calculator class.""" + + class CalculatorB(CalculatorBase): + name = "calc_b" + + def calculate(self, x: np.ndarray) -> np.ndarray: + return x * 3.0 + + return CalculatorB + + # Initialization tests + def test_init_empty(self): + """Test initialization with no calculators.""" + factory = SimpleCalculatorFactory() + assert factory.available_calculators == [] + + def test_init_with_calculators_dict(self, calculator_class_a, calculator_class_b): + """Test initialization with calculators dictionary.""" + factory = SimpleCalculatorFactory({ + "a": calculator_class_a, + "b": calculator_class_b, + }) + assert set(factory.available_calculators) == {"a", "b"} + + def test_class_level_calculators(self, calculator_class_a): + """Test using class-level _calculators attribute.""" + + class MyFactory(SimpleCalculatorFactory): + pass + + # Set class-level calculators + MyFactory._calculators = {"my_calc": calculator_class_a} + factory = MyFactory() + assert "my_calc" in factory.available_calculators + + # Available calculators tests + def test_available_calculators_returns_list(self, calculator_class_a): + """Test that available_calculators returns a list.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + result = factory.available_calculators + assert isinstance(result, list) + assert "a" in result + + # Create tests + def test_create_calculator(self, calculator_class_a, mock_model, mock_instrumental_parameters): + """Test creating a calculator.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + calculator = factory.create("a", mock_model, mock_instrumental_parameters) + assert isinstance(calculator, CalculatorBase) + assert calculator.model is mock_model + assert calculator.instrumental_parameters is mock_instrumental_parameters + + def test_create_with_kwargs(self, calculator_class_a, mock_model): + """Test creating calculator with additional kwargs.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + calculator = factory.create("a", mock_model, custom_option="value") + assert calculator.additional_kwargs == {"custom_option": "value"} + + def test_create_unknown_calculator_raises_error(self, calculator_class_a, mock_model): + """Test that creating unknown calculator raises ValueError.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + with pytest.raises(ValueError, match="Unknown calculator 'unknown'"): + factory.create("unknown", mock_model) + + def test_create_error_message_includes_available(self, calculator_class_a, calculator_class_b, mock_model): + """Test that error message includes available calculators.""" + factory = SimpleCalculatorFactory({ + "a": calculator_class_a, + "b": calculator_class_b, + }) + with pytest.raises(ValueError) as exc_info: + factory.create("unknown", mock_model) + assert "a" in str(exc_info.value) or "b" in str(exc_info.value) + + # Register tests + def test_register_calculator(self, calculator_class_a, calculator_class_b): + """Test registering a new calculator.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + factory.register("b", calculator_class_b) + assert "b" in factory.available_calculators + + def test_register_overwrites_existing(self, calculator_class_a, calculator_class_b): + """Test that registering with existing name overwrites.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + factory.register("a", calculator_class_b) + # Now "a" should create CalculatorB + calc = factory.create("a", MagicMock()) + assert calc.name == "calc_b" + + def test_register_invalid_class_raises_error(self, calculator_class_a): + """Test that registering non-CalculatorBase raises TypeError.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + + class NotACalculator: + pass + + with pytest.raises(TypeError, match="must be a subclass of CalculatorBase"): + factory.register("bad", NotACalculator) + + def test_register_non_class_raises_error(self, calculator_class_a): + """Test that registering a non-class raises TypeError.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + with pytest.raises(TypeError, match="must be a subclass of CalculatorBase"): + factory.register("bad", "not a class") + + # Unregister tests + def test_unregister_calculator(self, calculator_class_a, calculator_class_b): + """Test unregistering a calculator.""" + factory = SimpleCalculatorFactory({ + "a": calculator_class_a, + "b": calculator_class_b, + }) + factory.unregister("a") + assert "a" not in factory.available_calculators + assert "b" in factory.available_calculators + + def test_unregister_unknown_raises_error(self, calculator_class_a): + """Test that unregistering unknown calculator raises KeyError.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + with pytest.raises(KeyError, match="Calculator 'unknown' is not registered"): + factory.unregister("unknown") + + # Repr tests + def test_repr_empty_factory(self): + """Test __repr__ with empty factory.""" + factory = SimpleCalculatorFactory() + repr_str = repr(factory) + assert "SimpleCalculatorFactory" in repr_str + assert "available=[]" in repr_str + + def test_repr_with_calculators(self, calculator_class_a, calculator_class_b): + """Test __repr__ with calculators.""" + factory = SimpleCalculatorFactory({ + "a": calculator_class_a, + "b": calculator_class_b, + }) + repr_str = repr(factory) + assert "SimpleCalculatorFactory" in repr_str + assert "a" in repr_str or "b" in repr_str + + # Integration tests + def test_created_calculator_works(self, calculator_class_a, mock_model): + """Test that created calculator actually works.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + calculator = factory.create("a", mock_model) + x = np.array([1.0, 2.0, 3.0]) + result = calculator.calculate(x) + np.testing.assert_array_equal(result, np.array([2.0, 4.0, 6.0])) + + def test_create_multiple_calculators_independently( + self, calculator_class_a, calculator_class_b, mock_model + ): + """Test creating multiple independent calculators.""" + factory = SimpleCalculatorFactory({ + "a": calculator_class_a, + "b": calculator_class_b, + }) + + model_a = MagicMock(name="ModelA") + model_b = MagicMock(name="ModelB") + + calc_a = factory.create("a", model_a) + calc_b = factory.create("b", model_b) + + # They should be independent + assert calc_a.model is model_a + assert calc_b.model is model_b + assert calc_a is not calc_b + + # And calculate differently + x = np.array([1.0, 2.0]) + np.testing.assert_array_equal(calc_a.calculate(x), np.array([2.0, 4.0])) + np.testing.assert_array_equal(calc_b.calculate(x), np.array([3.0, 6.0])) + + +class TestFactoryStatelessness: + """Tests to verify that the factory is truly stateless.""" + + @pytest.fixture + def calculator_class(self): + """Create a calculator class with counter for instances.""" + + class CountingCalculator(CalculatorBase): + name = "counting" + instance_count = 0 + + def __init__(self, model, instrumental_parameters=None, **kwargs): + super().__init__(model, instrumental_parameters, **kwargs) + CountingCalculator.instance_count += 1 + self.instance_id = CountingCalculator.instance_count + + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + + # Reset counter before each test + CountingCalculator.instance_count = 0 + return CountingCalculator + + def test_factory_does_not_store_calculator_instances(self, calculator_class): + """Test that factory doesn't store references to created calculators.""" + factory = SimpleCalculatorFactory({"calc": calculator_class}) + mock_model = MagicMock() + + calc1 = factory.create("calc", mock_model) + calc2 = factory.create("calc", mock_model) + + # Each create should produce a new instance + assert calc1 is not calc2 + assert calc1.instance_id == 1 + assert calc2.instance_id == 2 + + def test_factory_has_no_current_calculator_attribute(self, calculator_class): + """Test that factory has no 'current' calculator state.""" + factory = SimpleCalculatorFactory({"calc": calculator_class}) + + # Should not have any attributes tracking current state + assert not hasattr(factory, "_current_calculator") + assert not hasattr(factory, "current_calculator") + assert not hasattr(factory, "_current") + + def test_multiple_factories_are_independent(self, calculator_class): + """Test that multiple factory instances are independent.""" + factory1 = SimpleCalculatorFactory({"calc": calculator_class}) + factory2 = SimpleCalculatorFactory({"calc": calculator_class}) + + mock_model = MagicMock() + + calc1 = factory1.create("calc", mock_model) + calc2 = factory2.create("calc", mock_model) + + # Each factory creates independent calculators + assert calc1 is not calc2 + + +class TestFactoryIsolation: + """Tests to ensure calculator registries don't bleed between factory instances or subclasses.""" + + @pytest.fixture + def calculator_class_x(self): + """First test calculator class.""" + class CalculatorX(CalculatorBase): + name = "x" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + return CalculatorX + + @pytest.fixture + def calculator_class_y(self): + """Second test calculator class.""" + class CalculatorY(CalculatorBase): + name = "y" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x * 2 + return CalculatorY + + @pytest.fixture + def calculator_class_z(self): + """Third test calculator class.""" + class CalculatorZ(CalculatorBase): + name = "z" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x * 3 + return CalculatorZ + + def test_instance_registration_does_not_affect_other_instances( + self, calculator_class_x, calculator_class_y, calculator_class_z + ): + """Test that registering to one instance doesn't affect others.""" + factory1 = SimpleCalculatorFactory({"x": calculator_class_x}) + factory2 = SimpleCalculatorFactory({"y": calculator_class_y}) + + # Register z to factory1 only + factory1.register("z", calculator_class_z) + + # factory1 should have both x and z + assert "x" in factory1.available_calculators + assert "z" in factory1.available_calculators + assert "y" not in factory1.available_calculators + + # factory2 should only have y + assert "y" in factory2.available_calculators + assert "x" not in factory2.available_calculators + assert "z" not in factory2.available_calculators + + def test_subclass_registration_does_not_affect_parent_or_siblings( + self, calculator_class_x, calculator_class_y + ): + """Test that subclass registries are independent.""" + + class FactoryA(SimpleCalculatorFactory): + _calculators = {"x": calculator_class_x} + + class FactoryB(SimpleCalculatorFactory): + _calculators = {"y": calculator_class_y} + + factory_a = FactoryA() + factory_b = FactoryB() + + # Each should have their own calculators + assert "x" in factory_a.available_calculators + assert "y" not in factory_a.available_calculators + + assert "y" in factory_b.available_calculators + assert "x" not in factory_b.available_calculators + + def test_class_level_registry_not_modified_by_instance_register( + self, calculator_class_x, calculator_class_y + ): + """Test that instance.register() doesn't modify class-level registry.""" + + class MyFactory(SimpleCalculatorFactory): + _calculators = {"x": calculator_class_x} + + # Create instance and register to it + factory = MyFactory() + factory.register("y", calculator_class_y) + + # Instance should have both + assert "x" in factory.available_calculators + assert "y" in factory.available_calculators + + # Create new instance - should NOT have y + factory2 = MyFactory() + assert "x" in factory2.available_calculators + assert "y" not in factory2.available_calculators + + def test_unregister_from_one_instance_does_not_affect_others( + self, calculator_class_x + ): + """Test that unregistering from one instance doesn't affect others.""" + factory1 = SimpleCalculatorFactory({"x": calculator_class_x}) + factory2 = SimpleCalculatorFactory({"x": calculator_class_x}) + + # Unregister from factory1 + factory1.unregister("x") + + # factory1 should not have x + assert "x" not in factory1.available_calculators + + # factory2 should still have x + assert "x" in factory2.available_calculators + + +class TestFactoryErrorHandling: + """Tests for improved error handling and validation.""" + + @pytest.fixture + def calculator_class(self): + """Simple test calculator.""" + class TestCalc(CalculatorBase): + name = "test" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + return TestCalc + + def test_register_with_empty_name_raises_error(self, calculator_class): + """Test that empty calculator name raises ValueError.""" + factory = SimpleCalculatorFactory() + with pytest.raises(ValueError, match="non-empty string"): + factory.register("", calculator_class) + + def test_register_with_non_string_name_raises_error(self, calculator_class): + """Test that non-string calculator name raises ValueError.""" + factory = SimpleCalculatorFactory() + with pytest.raises(ValueError, match="non-empty string"): + factory.register(123, calculator_class) + + def test_register_overwrites_with_warning(self, calculator_class): + """Test that overwriting existing calculator issues warning.""" + factory = SimpleCalculatorFactory({"test": calculator_class}) + + class NewCalc(CalculatorBase): + name = "new" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x * 2 + + with pytest.warns(UserWarning, match="Overwriting existing calculator 'test'"): + factory.register("test", NewCalc) + + def test_create_with_non_string_name_raises_error(self, calculator_class): + """Test that create with non-string name raises ValueError.""" + factory = SimpleCalculatorFactory({"test": calculator_class}) + with pytest.raises(ValueError, match="must be a string"): + factory.create(123, MagicMock()) + + def test_create_with_none_model_raises_error(self, calculator_class): + """Test that create with None model raises TypeError.""" + factory = SimpleCalculatorFactory({"test": calculator_class}) + with pytest.raises(TypeError, match="model cannot be None"): + factory.create("test", None) + + def test_create_unknown_calculator_shows_available_in_error(self, calculator_class): + """Test that error message includes available calculators.""" + factory = SimpleCalculatorFactory({"calc1": calculator_class}) + with pytest.raises(ValueError, match="calc1") as exc_info: + factory.create("unknown", MagicMock()) + assert "Available calculators" in str(exc_info.value) + + def test_create_empty_factory_error_shows_none_available(self): + """Test error message when factory has no calculators.""" + factory = SimpleCalculatorFactory() + with pytest.raises(ValueError, match="none") as exc_info: + factory.create("anything", MagicMock()) + assert "Available calculators: none" in str(exc_info.value) + + def test_create_wraps_calculator_init_errors(self, calculator_class): + """Test that calculator initialization errors are wrapped.""" + + class BrokenCalc(CalculatorBase): + name = "broken" + def __init__(self, model, instrumental_parameters=None, **kwargs): + raise RuntimeError("Something went wrong") + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + + factory = SimpleCalculatorFactory({"broken": BrokenCalc}) + with pytest.raises(RuntimeError, match="Failed to create calculator 'broken'"): + factory.create("broken", MagicMock()) + + +class TestCalculatorKwargsProperty: + """Tests for the additional_kwargs property on CalculatorBase.""" + + @pytest.fixture + def calculator_class(self): + """Simple calculator class for testing.""" + class TestCalc(CalculatorBase): + name = "test" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + return TestCalc + + def test_additional_kwargs_accessible(self, calculator_class): + """Test that additional_kwargs property is accessible.""" + calc = calculator_class( + MagicMock(), + custom_param="value", + another_option=42 + ) + kwargs = calc.additional_kwargs + assert isinstance(kwargs, dict) + assert kwargs["custom_param"] == "value" + assert kwargs["another_option"] == 42 + + def test_additional_kwargs_empty_when_none_provided(self, calculator_class): + """Test that additional_kwargs is empty dict when no kwargs provided.""" + calc = calculator_class(MagicMock()) + assert calc.additional_kwargs == {} + + def test_additional_kwargs_via_factory(self, calculator_class): + """Test that kwargs passed through factory are accessible.""" + factory = SimpleCalculatorFactory({"test": calculator_class}) + calc = factory.create( + "test", + MagicMock(), + option1="value1", + option2=123 + ) + assert calc.additional_kwargs["option1"] == "value1" + assert calc.additional_kwargs["option2"] == 123 From f093ee127f132f52dcb0a89a6ddb6186f374b4f8 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Mon, 1 Dec 2025 13:59:17 +0100 Subject: [PATCH 02/14] ruff nonsense --- .../fitting/calculators/__init__.py | 8 ++-- .../fitting/calculators/calculator_base.py | 18 ++++---- .../fitting/calculators/calculator_factory.py | 42 +++++++------------ 3 files changed, 28 insertions(+), 40 deletions(-) diff --git a/src/easyscience/fitting/calculators/__init__.py b/src/easyscience/fitting/calculators/__init__.py index e0c5d480..e2c1ea89 100644 --- a/src/easyscience/fitting/calculators/__init__.py +++ b/src/easyscience/fitting/calculators/__init__.py @@ -8,8 +8,8 @@ from .interface_factory import InterfaceFactoryTemplate __all__ = [ - "CalculatorBase", - "CalculatorFactoryBase", - "SimpleCalculatorFactory", - "InterfaceFactoryTemplate", # Deprecated, kept for backwards compatibility + 'CalculatorBase', + 'CalculatorFactoryBase', + 'SimpleCalculatorFactory', + 'InterfaceFactoryTemplate', # Deprecated, kept for backwards compatibility ] diff --git a/src/easyscience/fitting/calculators/calculator_base.py b/src/easyscience/fitting/calculators/calculator_base.py index 523067f9..8688a7a5 100644 --- a/src/easyscience/fitting/calculators/calculator_base.py +++ b/src/easyscience/fitting/calculators/calculator_base.py @@ -80,7 +80,7 @@ def calculate(self, x): return computed_values """ - name: str = "base" + name: str = 'base' def __init__( self, @@ -109,11 +109,11 @@ def __init__( Additional calculator-specific options. """ if model is None: - raise ValueError("Model cannot be None") - + raise ValueError('Model cannot be None') + # Initialize NewBase with naming super().__init__(unique_name=unique_name, display_name=display_name) - + self._model = model self._instrumental_parameters = instrumental_parameters self._additional_kwargs = kwargs @@ -146,7 +146,7 @@ def model(self, new_model: NewBase) -> None: If the new model is None. """ if new_model is None: - raise ValueError("Model cannot be None") + raise ValueError('Model cannot be None') self._model = new_model @property @@ -250,12 +250,12 @@ def calculate(self, x: np.ndarray) -> np.ndarray: def __repr__(self) -> str: """Return a string representation of the calculator.""" model_name = getattr(self._model, 'name', type(self._model).__name__) - instr_info = "" + instr_info = '' if self._instrumental_parameters is not None: instr_name = getattr( self._instrumental_parameters, 'name', - type(self._instrumental_parameters).__name__ # default to class name if no 'name' attribute + type(self._instrumental_parameters).__name__, # default to class name if no 'name' attribute ) - instr_info = f", instrumental_parameters={instr_name}" - return f"{self.__class__.__name__}(model={model_name}{instr_info})" + instr_info = f', instrumental_parameters={instr_name}' + return f'{self.__class__.__name__}(model={model_name}{instr_info})' diff --git a/src/easyscience/fitting/calculators/calculator_factory.py b/src/easyscience/fitting/calculators/calculator_factory.py index 50a55e10..de13c5e4 100644 --- a/src/easyscience/fitting/calculators/calculator_factory.py +++ b/src/easyscience/fitting/calculators/calculator_factory.py @@ -140,7 +140,7 @@ def create( def __repr__(self) -> str: """Return a string representation of the factory.""" - return f"{self.__class__.__name__}(available={self.available_calculators})" + return f'{self.__class__.__name__}(available={self.available_calculators})' class SimpleCalculatorFactory(CalculatorFactoryBase): @@ -252,25 +252,20 @@ def create( If model is None or instrumental_parameters has wrong type. """ if not isinstance(calculator_name, str): - raise ValueError(f"calculator_name must be a string, got {type(calculator_name).__name__}") - + raise ValueError(f'calculator_name must be a string, got {type(calculator_name).__name__}') + if calculator_name not in self._calculators: - available = ", ".join(self.available_calculators) if self.available_calculators else "none" - raise ValueError( - f"Unknown calculator '{calculator_name}'. " - f"Available calculators: {available}" - ) - + available = ', '.join(self.available_calculators) if self.available_calculators else 'none' + raise ValueError(f"Unknown calculator '{calculator_name}'. Available calculators: {available}") + if model is None: - raise TypeError("model cannot be None") - + raise TypeError('model cannot be None') + calculator_class = self._calculators[calculator_name] try: return calculator_class(model, instrumental_parameters, **kwargs) except Exception as e: - raise type(e)( - f"Failed to create calculator '{calculator_name}': {e}" - ) from e + raise type(e)(f"Failed to create calculator '{calculator_name}': {e}") from e def register(self, name: str, calculator_class: Type[CalculatorBase]) -> None: """ @@ -289,7 +284,7 @@ def register(self, name: str, calculator_class: Type[CalculatorBase]) -> None: If calculator_class is not a subclass of CalculatorBase. ValueError If name is empty or not a string. - + Warnings -------- If overwriting an existing calculator, a warning is issued. @@ -300,21 +295,14 @@ def register(self, name: str, calculator_class: Type[CalculatorBase]) -> None: from .calculator_base import CalculatorBase if not isinstance(name, str) or not name: - raise ValueError("Calculator name must be a non-empty string") + raise ValueError('Calculator name must be a non-empty string') if not (isinstance(calculator_class, type) and issubclass(calculator_class, CalculatorBase)): - raise TypeError( - f"calculator_class must be a subclass of CalculatorBase, " - f"got {type(calculator_class).__name__}" - ) - + raise TypeError(f'calculator_class must be a subclass of CalculatorBase, got {type(calculator_class).__name__}') + if name in self._calculators: - warnings.warn( - f"Overwriting existing calculator '{name}' in {self.__class__.__name__}", - UserWarning, - stacklevel=2 - ) - + warnings.warn(f"Overwriting existing calculator '{name}' in {self.__class__.__name__}", UserWarning, stacklevel=2) + self._calculators[name] = calculator_class def unregister(self, name: str) -> None: From b8a307ea9d189d803e6c451022c78eb691329bf4 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Fri, 5 Dec 2025 15:05:27 +0100 Subject: [PATCH 03/14] added updated impl of model collection --- src/easyscience/base_classes/__init__.py | 2 + .../base_classes/model_collection.py | 287 ++++++++ .../base_classes/test_model_collection.py | 657 ++++++++++++++++++ 3 files changed, 946 insertions(+) create mode 100644 src/easyscience/base_classes/model_collection.py create mode 100644 tests/unit_tests/base_classes/test_model_collection.py diff --git a/src/easyscience/base_classes/__init__.py b/src/easyscience/base_classes/__init__.py index 9f3ba080..9a8dc2d3 100644 --- a/src/easyscience/base_classes/__init__.py +++ b/src/easyscience/base_classes/__init__.py @@ -1,6 +1,7 @@ from .based_base import BasedBase from .collection_base import CollectionBase from .model_base import ModelBase +from .model_collection import ModelCollection from .new_base import NewBase from .obj_base import ObjBase @@ -9,5 +10,6 @@ CollectionBase, ObjBase, ModelBase, + ModelCollection, NewBase, ] diff --git a/src/easyscience/base_classes/model_collection.py b/src/easyscience/base_classes/model_collection.py new file mode 100644 index 00000000..12314c77 --- /dev/null +++ b/src/easyscience/base_classes/model_collection.py @@ -0,0 +1,287 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project None: + """Add an item to the collection and set up graph edges.""" + if not isinstance(item, NewBase): + raise TypeError(f'Items must be NewBase objects, got {type(item)}') + self._data.append(item) + self._global_object.map.add_edge(self, item) + self._global_object.map.reset_type(item, 'created_internal') + if self._interface is not None and hasattr(item, 'interface'): + setattr(item, 'interface', self._interface) + + def _remove_item(self, item: NewBase) -> None: + """Remove an item from the collection and clean up graph edges.""" + self._global_object.map.prune_vertex_from_edge(self, item) + + @property + def name(self) -> str: + """Get the name of the collection.""" + return self._name + + @name.setter + def name(self, new_name: str) -> None: + """Set the name of the collection.""" + self._name = new_name + + @property + def interface(self) -> Optional[InterfaceFactoryTemplate]: + """Get the current interface of the collection.""" + return self._interface + + @interface.setter + def interface(self, new_interface: Optional[InterfaceFactoryTemplate]) -> None: + """Set the interface and propagate to all items.""" + self._interface = new_interface + for item in self._data: + if hasattr(item, 'interface'): + setattr(item, 'interface', new_interface) + + # MutableSequence abstract methods + + # Use @overload to provide precise type hints for different __getitem__ argument types + @overload + def __getitem__(self, idx: int) -> T: ... + @overload + def __getitem__(self, idx: slice) -> 'ModelCollection[T]': ... + @overload + def __getitem__(self, idx: str) -> T: ... + + def __getitem__(self, idx: Union[int, slice, str]) -> Union[T, 'ModelCollection[T]']: + """ + Get an item by index, slice, or name. + + :param idx: Index, slice, or name of the item + :return: The item or a new collection for slices + """ + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + return self.__class__(self._name, *[self._data[i] for i in range(start, stop, step)]) + if isinstance(idx, str): + # Search by name + for item in self._data: + if hasattr(item, 'name') and getattr(item, 'name') == idx: + return item # type: ignore[return-value] + if hasattr(item, 'unique_name') and item.unique_name == idx: + return item # type: ignore[return-value] + raise KeyError(f'No item with name "{idx}" found') + return self._data[idx] # type: ignore[return-value] + + @overload + def __setitem__(self, idx: int, value: T) -> None: ... + @overload + def __setitem__(self, idx: slice, value: Iterable[T]) -> None: ... + + def __setitem__(self, idx: Union[int, slice], value: Union[T, Iterable[T]]) -> None: + """ + Set an item at an index. + + :param idx: Index to set + :param value: New value + """ + if isinstance(idx, slice): + # Handle slice assignment + values = list(value) # type: ignore[arg-type] + # Remove old items + start, stop, step = idx.indices(len(self)) + for i in range(start, stop, step): + self._remove_item(self._data[i]) + # Set new items + self._data[idx] = values # type: ignore[assignment] + for v in values: + self._global_object.map.add_edge(self, v) + self._global_object.map.reset_type(v, 'created_internal') + if self._interface is not None and hasattr(v, 'interface'): + setattr(v, 'interface', self._interface) + else: + if not isinstance(value, NewBase): + raise TypeError(f'Items must be NewBase objects, got {type(value)}') + + old_item = self._data[idx] + self._remove_item(old_item) + + self._data[idx] = value # type: ignore[assignment] + self._global_object.map.add_edge(self, value) + self._global_object.map.reset_type(value, 'created_internal') + if self._interface is not None and hasattr(value, 'interface'): + setattr(value, 'interface', self._interface) + + @overload + def __delitem__(self, idx: int) -> None: ... + @overload + def __delitem__(self, idx: slice) -> None: ... + @overload + def __delitem__(self, idx: str) -> None: ... + + def __delitem__(self, idx: Union[int, slice, str]) -> None: + """ + Delete an item by index, slice, or name. + + :param idx: Index, slice, or name of item to delete + """ + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + indices = list(range(start, stop, step)) + # Remove in reverse order to maintain indices + for i in reversed(indices): + item = self._data[i] + self._remove_item(item) + del self._data[i] + elif isinstance(idx, str): + for i, item in enumerate(self._data): + if hasattr(item, 'name') and getattr(item, 'name') == idx: + idx = i + break + if hasattr(item, 'unique_name') and item.unique_name == idx: + idx = i + break + else: + raise KeyError(f'No item with name "{idx}" found') + + item = self._data[idx] + self._remove_item(item) + del self._data[idx] + else: + item = self._data[idx] + self._remove_item(item) + del self._data[idx] + + def __len__(self) -> int: + """Return the number of items in the collection.""" + return len(self._data) + + def insert(self, index: int, value: T) -> None: + """ + Insert an item at an index. + + :param index: Index to insert at + :param value: Item to insert + """ + if not isinstance(value, NewBase): + raise TypeError(f'Items must be NewBase objects, got {type(value)}') + + self._data.insert(index, value) # type: ignore[arg-type] + self._global_object.map.add_edge(self, value) + self._global_object.map.reset_type(value, 'created_internal') + if self._interface is not None and hasattr(value, 'interface'): + setattr(value, 'interface', self._interface) + + # Additional utility methods + + @property + def data(self) -> tuple: + """Return the data as a tuple.""" + return tuple(self._data) + + def sort(self, mapping: Callable[[T], Any], reverse: bool = False) -> None: + """ + Sort the collection according to the given mapping. + + :param mapping: Mapping function to sort by + :param reverse: Whether to reverse the sort + """ + self._data.sort(key=mapping, reverse=reverse) # type: ignore[arg-type] + + def __repr__(self) -> str: + return f'{self.__class__.__name__} `{self._name}` of length {len(self)}' + + def __iter__(self) -> Any: + return iter(self._data) + + # Serialization support + + def _convert_to_dict(self, in_dict: dict, encoder: Any, skip: Optional[List[str]] = None, **kwargs: Any) -> dict: + """Convert the collection to a dictionary for serialization.""" + if skip is None: + skip = [] + d: dict = {} + if hasattr(self, '_modify_dict'): + d = self._modify_dict(skip=skip, **kwargs) # type: ignore[attr-defined] + in_dict['data'] = [encoder._convert_to_dict(item, skip=skip, **kwargs) for item in self._data] + return {**in_dict, **d} + + def get_all_variables(self) -> List[Any]: + """Get all variables from all items in the collection.""" + variables: List[Any] = [] + for item in self._data: + if hasattr(item, 'get_all_variables'): + variables.extend(item.get_all_variables()) # type: ignore[attr-defined] + return variables + + def get_all_parameters(self) -> List[Any]: + """Get all parameters from all items in the collection.""" + return [var for var in self.get_all_variables() if isinstance(var, Parameter)] + + def get_fit_parameters(self) -> List[Any]: + """Get all fittable parameters from all items in the collection.""" + parameters: List[Any] = [] + for item in self._data: + if hasattr(item, 'get_fit_parameters'): + parameters.extend(item.get_fit_parameters()) # type: ignore[attr-defined] + return parameters diff --git a/tests/unit_tests/base_classes/test_model_collection.py b/tests/unit_tests/base_classes/test_model_collection.py new file mode 100644 index 00000000..7ea72c17 --- /dev/null +++ b/tests/unit_tests/base_classes/test_model_collection.py @@ -0,0 +1,657 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project str: + return self._name + + @property + def value(self) -> Parameter: + return self._value + + @value.setter + def value(self, new_value: float) -> None: + self._value.value = new_value + + +class DerivedModelCollection(ModelCollection): + """A derived class for testing inheritance.""" + pass + + +class_constructors = [ModelCollection, DerivedModelCollection] + + +@pytest.fixture +def clear_global(): + """Clear the global object map before each test.""" + global_object.map._clear() + yield + global_object.map._clear() + + +@pytest.fixture +def sample_items(): + """Create sample items for testing.""" + return [ + MockModelItem(name='item1', value=1.0), + MockModelItem(name='item2', value=2.0), + MockModelItem(name='item3', value=3.0), + ] + + +# ============================================================================= +# Constructor Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_init_empty(cls, clear_global): + """Test creating an empty collection.""" + coll = cls('test_collection') + assert coll.name == 'test_collection' + assert len(coll) == 0 + assert coll.interface is None + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_init_with_items(cls, clear_global, sample_items): + """Test creating a collection with initial items.""" + coll = cls('test_collection', *sample_items) + assert coll.name == 'test_collection' + assert len(coll) == 3 + for i, item in enumerate(coll): + assert item.name == sample_items[i].name + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_init_with_unique_name(cls, clear_global): + """Test creating a collection with a custom unique_name.""" + coll = cls('test_collection', unique_name='custom_unique') + assert coll.unique_name == 'custom_unique' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_init_with_display_name(cls, clear_global): + """Test creating a collection with a custom display_name.""" + coll = cls('test_collection', display_name='My Display Name') + assert coll.display_name == 'My Display Name' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_init_with_list_arg(cls, clear_global, sample_items): + """Test creating a collection with a list of items (should flatten).""" + coll = cls('test_collection', sample_items) + assert len(coll) == 3 + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_init_type_error(cls, clear_global): + """Test that adding non-NewBase items raises TypeError.""" + with pytest.raises(TypeError): + cls('test_collection', 'not_a_newbase_object') + + +# ============================================================================= +# Name Property Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_name_getter(cls, clear_global): + """Test getting the collection name.""" + coll = cls('my_collection') + assert coll.name == 'my_collection' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_name_setter(cls, clear_global): + """Test setting the collection name.""" + coll = cls('old_name') + coll.name = 'new_name' + assert coll.name == 'new_name' + + +# ============================================================================= +# Interface Property Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_interface_default(cls, clear_global): + """Test that interface defaults to None.""" + coll = cls('test_collection') + assert coll.interface is None + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_interface_propagation(cls, clear_global, sample_items): + """Test that setting interface propagates to items.""" + # Add interface attribute to items for this test + for item in sample_items: + item.interface = None + + coll = cls('test_collection', *sample_items) + + class MockInterface: + pass + + mock_interface = MockInterface() + coll.interface = mock_interface + + assert coll.interface is mock_interface + for item in coll: + assert item.interface is mock_interface + + +# ============================================================================= +# __getitem__ Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_getitem_int(cls, clear_global, sample_items): + """Test getting items by integer index.""" + coll = cls('test_collection', *sample_items) + assert coll[0].name == 'item1' + assert coll[1].name == 'item2' + assert coll[2].name == 'item3' + assert coll[-1].name == 'item3' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_getitem_int_out_of_range(cls, clear_global, sample_items): + """Test that out of range index raises IndexError.""" + coll = cls('test_collection', *sample_items) + with pytest.raises(IndexError): + _ = coll[100] + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_getitem_slice(cls, clear_global, sample_items): + """Test getting items by slice.""" + coll = cls('test_collection', *sample_items) + sliced = coll[0:2] + assert isinstance(sliced, cls) + assert len(sliced) == 2 + assert sliced[0].name == 'item1' + assert sliced[1].name == 'item2' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_getitem_str_by_name(cls, clear_global, sample_items): + """Test getting items by name string.""" + coll = cls('test_collection', *sample_items) + item = coll['item2'] + assert item.name == 'item2' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_getitem_str_by_unique_name(cls, clear_global, sample_items): + """Test getting items by unique_name string.""" + coll = cls('test_collection', *sample_items) + unique_name = sample_items[1].unique_name + item = coll[unique_name] + assert item.unique_name == unique_name + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_getitem_str_not_found(cls, clear_global, sample_items): + """Test that getting non-existent name raises KeyError.""" + coll = cls('test_collection', *sample_items) + with pytest.raises(KeyError): + _ = coll['nonexistent'] + + +# ============================================================================= +# __setitem__ Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_setitem_int(cls, clear_global, sample_items): + """Test setting items by integer index.""" + coll = cls('test_collection', *sample_items) + new_item = MockModelItem(name='new_item', value=99.0) + old_item = coll[1] + + coll[1] = new_item + + assert len(coll) == 3 + assert coll[1].name == 'new_item' + assert coll[1].value.value == 99.0 + + # Check graph edges + edges = global_object.map.get_edges(coll) + assert new_item.unique_name in edges + assert old_item.unique_name not in edges + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_setitem_type_error(cls, clear_global, sample_items): + """Test that setting non-NewBase item raises TypeError.""" + coll = cls('test_collection', *sample_items) + with pytest.raises(TypeError): + coll[0] = 'not_a_newbase_object' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_setitem_slice(cls, clear_global, sample_items): + """Test setting items by slice.""" + coll = cls('test_collection', *sample_items) + new_items = [ + MockModelItem(name='new1', value=10.0), + MockModelItem(name='new2', value=20.0), + ] + + coll[0:2] = new_items + + assert len(coll) == 3 + assert coll[0].name == 'new1' + assert coll[1].name == 'new2' + assert coll[2].name == 'item3' + + +# ============================================================================= +# __delitem__ Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_delitem_int(cls, clear_global, sample_items): + """Test deleting items by integer index.""" + coll = cls('test_collection', *sample_items) + deleted_item = coll[1] + + del coll[1] + + assert len(coll) == 2 + assert coll[0].name == 'item1' + assert coll[1].name == 'item3' + + # Check graph edges + edges = global_object.map.get_edges(coll) + assert deleted_item.unique_name not in edges + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_delitem_slice(cls, clear_global, sample_items): + """Test deleting items by slice.""" + coll = cls('test_collection', *sample_items) + + del coll[0:2] + + assert len(coll) == 1 + assert coll[0].name == 'item3' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_delitem_str_by_name(cls, clear_global, sample_items): + """Test deleting items by name string.""" + coll = cls('test_collection', *sample_items) + + del coll['item2'] + + assert len(coll) == 2 + assert 'item2' not in [item.name for item in coll] + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_delitem_str_not_found(cls, clear_global, sample_items): + """Test that deleting non-existent name raises KeyError.""" + coll = cls('test_collection', *sample_items) + with pytest.raises(KeyError): + del coll['nonexistent'] + + +# ============================================================================= +# __len__ Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +@pytest.mark.parametrize('count', [0, 1, 3, 5]) +def test_ModelCollection_len(cls, clear_global, count): + """Test __len__ returns correct count.""" + items = [MockModelItem(name=f'item{i}', value=float(i)) for i in range(count)] + coll = cls('test_collection', *items) + assert len(coll) == count + + +# ============================================================================= +# insert Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_insert(cls, clear_global, sample_items): + """Test inserting items at an index.""" + coll = cls('test_collection', *sample_items) + new_item = MockModelItem(name='inserted', value=99.0) + + coll.insert(1, new_item) + + assert len(coll) == 4 + assert coll[0].name == 'item1' + assert coll[1].name == 'inserted' + assert coll[2].name == 'item2' + assert coll[3].name == 'item3' + + # Check graph edges + edges = global_object.map.get_edges(coll) + assert new_item.unique_name in edges + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_insert_type_error(cls, clear_global, sample_items): + """Test that inserting non-NewBase item raises TypeError.""" + coll = cls('test_collection', *sample_items) + with pytest.raises(TypeError): + coll.insert(0, 'not_a_newbase_object') + + +# ============================================================================= +# append Tests (inherited from MutableSequence) +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_append(cls, clear_global, sample_items): + """Test appending items.""" + coll = cls('test_collection', *sample_items) + new_item = MockModelItem(name='appended', value=99.0) + + coll.append(new_item) + + assert len(coll) == 4 + assert coll[-1].name == 'appended' + + # Check graph edges + edges = global_object.map.get_edges(coll) + assert new_item.unique_name in edges + + +# ============================================================================= +# data Property Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_data_property(cls, clear_global, sample_items): + """Test that data property returns tuple of items.""" + coll = cls('test_collection', *sample_items) + data = coll.data + assert isinstance(data, tuple) + assert len(data) == 3 + for i, item in enumerate(data): + assert item.name == sample_items[i].name + + +# ============================================================================= +# sort Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_sort(cls, clear_global): + """Test sorting the collection.""" + items = [ + MockModelItem(name='c', value=3.0), + MockModelItem(name='a', value=1.0), + MockModelItem(name='b', value=2.0), + ] + coll = cls('test_collection', *items) + + coll.sort(lambda x: x.value.value) + + assert coll[0].name == 'a' + assert coll[1].name == 'b' + assert coll[2].name == 'c' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_sort_reverse(cls, clear_global): + """Test sorting the collection in reverse.""" + items = [ + MockModelItem(name='a', value=1.0), + MockModelItem(name='c', value=3.0), + MockModelItem(name='b', value=2.0), + ] + coll = cls('test_collection', *items) + + coll.sort(lambda x: x.value.value, reverse=True) + + assert coll[0].name == 'c' + assert coll[1].name == 'b' + assert coll[2].name == 'a' + + +# ============================================================================= +# __repr__ Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_repr(cls, clear_global, sample_items): + """Test string representation.""" + coll = cls('my_collection', *sample_items) + repr_str = repr(coll) + assert cls.__name__ in repr_str + assert 'my_collection' in repr_str + assert '3' in repr_str + + +# ============================================================================= +# __iter__ Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_iter(cls, clear_global, sample_items): + """Test iteration over collection.""" + coll = cls('test_collection', *sample_items) + + names = [item.name for item in coll] + assert names == ['item1', 'item2', 'item3'] + + +# ============================================================================= +# get_all_variables Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_get_all_variables(cls, clear_global, sample_items): + """Test getting all variables from items.""" + coll = cls('test_collection', *sample_items) + variables = coll.get_all_variables() + + # Each MockModelItem has one Parameter (value) + assert len(variables) == 3 + for var in variables: + assert isinstance(var, Parameter) + + +# ============================================================================= +# get_all_parameters Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_get_all_parameters(cls, clear_global, sample_items): + """Test getting all parameters from items.""" + coll = cls('test_collection', *sample_items) + parameters = coll.get_all_parameters() + + assert len(parameters) == 3 + for param in parameters: + assert isinstance(param, Parameter) + + +# ============================================================================= +# get_fit_parameters Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_get_fit_parameters(cls, clear_global, sample_items): + """Test getting fit parameters from items.""" + # Fix one parameter so we can test filtering + sample_items[0].value.fixed = True + + coll = cls('test_collection', *sample_items) + fit_params = coll.get_fit_parameters() + + # All 3 parameters should be returned (get_fit_parameters on items) + # since MockModelItem.get_fit_parameters returns free params + assert len(fit_params) == 2 + + +# ============================================================================= +# Graph Edge Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_graph_edges(cls, clear_global, sample_items): + """Test that graph edges are correctly maintained.""" + coll = cls('test_collection', *sample_items) + + edges = global_object.map.get_edges(coll) + assert len(edges) == 3 + + for item in sample_items: + assert item.unique_name in edges + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_graph_edges_after_append(cls, clear_global, sample_items): + """Test graph edges are updated after append.""" + coll = cls('test_collection', *sample_items) + new_item = MockModelItem(name='new', value=99.0) + + coll.append(new_item) + + edges = global_object.map.get_edges(coll) + assert len(edges) == 4 + assert new_item.unique_name in edges + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_graph_edges_after_delete(cls, clear_global, sample_items): + """Test graph edges are updated after delete.""" + coll = cls('test_collection', *sample_items) + deleted_item = sample_items[1] + + del coll[1] + + edges = global_object.map.get_edges(coll) + assert len(edges) == 2 + assert deleted_item.unique_name not in edges + + +# ============================================================================= +# MutableSequence Interface Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_extend(cls, clear_global, sample_items): + """Test extend method (inherited from MutableSequence).""" + coll = cls('test_collection', sample_items[0]) + coll.extend([sample_items[1], sample_items[2]]) + + assert len(coll) == 3 + assert coll[1].name == 'item2' + assert coll[2].name == 'item3' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_pop(cls, clear_global, sample_items): + """Test pop method (inherited from MutableSequence).""" + coll = cls('test_collection', *sample_items) + + popped = coll.pop() + assert popped.name == 'item3' + assert len(coll) == 2 + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_pop_index(cls, clear_global, sample_items): + """Test pop method with index (inherited from MutableSequence).""" + coll = cls('test_collection', *sample_items) + + popped = coll.pop(0) + assert popped.name == 'item1' + assert len(coll) == 2 + assert coll[0].name == 'item2' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_remove(cls, clear_global, sample_items): + """Test remove method (inherited from MutableSequence).""" + coll = cls('test_collection', *sample_items) + item_to_remove = sample_items[1] + + coll.remove(item_to_remove) + + assert len(coll) == 2 + assert item_to_remove not in coll + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_clear(cls, clear_global, sample_items): + """Test clear method (inherited from MutableSequence).""" + coll = cls('test_collection', *sample_items) + + coll.clear() + + assert len(coll) == 0 + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_reverse(cls, clear_global, sample_items): + """Test reverse method (inherited from MutableSequence).""" + coll = cls('test_collection', *sample_items) + + coll.reverse() + + assert coll[0].name == 'item3' + assert coll[1].name == 'item2' + assert coll[2].name == 'item1' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_count(cls, clear_global, sample_items): + """Test count method (inherited from MutableSequence).""" + coll = cls('test_collection', *sample_items) + + count = coll.count(sample_items[0]) + assert count == 1 + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_index(cls, clear_global, sample_items): + """Test index method (inherited from MutableSequence).""" + coll = cls('test_collection', *sample_items) + + idx = coll.index(sample_items[1]) + assert idx == 1 + + +# ============================================================================= +# Contains Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_contains(cls, clear_global, sample_items): + """Test __contains__ (in operator).""" + coll = cls('test_collection', *sample_items) + + assert sample_items[0] in coll + assert sample_items[1] in coll + + new_item = MockModelItem(name='not_in_collection', value=999.0) + assert new_item not in coll From 8704002024134ec762e5e2636cd4c4e4f02820a2 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Fri, 5 Dec 2025 15:06:34 +0100 Subject: [PATCH 04/14] ruff --- src/easyscience/base_classes/model_collection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/easyscience/base_classes/model_collection.py b/src/easyscience/base_classes/model_collection.py index 12314c77..c9a12c28 100644 --- a/src/easyscience/base_classes/model_collection.py +++ b/src/easyscience/base_classes/model_collection.py @@ -15,9 +15,8 @@ from typing import Union from typing import overload -from .new_base import NewBase - from ..variable import Parameter +from .new_base import NewBase if TYPE_CHECKING: from ..fitting.calculators import InterfaceFactoryTemplate From f158891774fd451024fba905b8a496cf84d8347c Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Thu, 11 Dec 2025 21:50:52 +0100 Subject: [PATCH 05/14] temporary inheritance added --- src/easyscience/base_classes/collection_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/easyscience/base_classes/collection_base.py b/src/easyscience/base_classes/collection_base.py index 3cc0586a..e08ae26c 100644 --- a/src/easyscience/base_classes/collection_base.py +++ b/src/easyscience/base_classes/collection_base.py @@ -18,6 +18,7 @@ from ..variable.descriptor_base import DescriptorBase from .based_base import BasedBase +from .new_base import NewBase if TYPE_CHECKING: from ..fitting.calculators import InterfaceFactoryTemplate @@ -64,7 +65,7 @@ def __init__( _kwargs[key] = item kwargs = _kwargs for item in list(kwargs.values()) + _args: - if not issubclass(type(item), (DescriptorBase, BasedBase)): + if not issubclass(type(item), (DescriptorBase, BasedBase, NewBase)): raise AttributeError('A collection can only be formed from easyscience objects.') args = _args _kwargs = {} From b4b4375d37be5f479e0ad65fc7a6f4342be0cf3a Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Sat, 13 Dec 2025 11:55:48 +0100 Subject: [PATCH 06/14] runn format --- src/easyscience/base_classes/model_collection.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/easyscience/base_classes/model_collection.py b/src/easyscience/base_classes/model_collection.py index c9a12c28..d26225ae 100644 --- a/src/easyscience/base_classes/model_collection.py +++ b/src/easyscience/base_classes/model_collection.py @@ -161,10 +161,10 @@ def __setitem__(self, idx: Union[int, slice], value: Union[T, Iterable[T]]) -> N else: if not isinstance(value, NewBase): raise TypeError(f'Items must be NewBase objects, got {type(value)}') - + old_item = self._data[idx] self._remove_item(old_item) - + self._data[idx] = value # type: ignore[assignment] self._global_object.map.add_edge(self, value) self._global_object.map.reset_type(value, 'created_internal') @@ -202,7 +202,7 @@ def __delitem__(self, idx: Union[int, slice, str]) -> None: break else: raise KeyError(f'No item with name "{idx}" found') - + item = self._data[idx] self._remove_item(item) del self._data[idx] @@ -224,7 +224,7 @@ def insert(self, index: int, value: T) -> None: """ if not isinstance(value, NewBase): raise TypeError(f'Items must be NewBase objects, got {type(value)}') - + self._data.insert(index, value) # type: ignore[arg-type] self._global_object.map.add_edge(self, value) self._global_object.map.reset_type(value, 'created_internal') From ec9f07d6e849b1d6c6f3ad0f5f4ef6102da463f6 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Mon, 15 Dec 2025 14:06:19 +0100 Subject: [PATCH 07/14] removed direct access to map._store --- src/easyscience/variable/parameter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/easyscience/variable/parameter.py b/src/easyscience/variable/parameter.py index da5dfe59..05712222 100644 --- a/src/easyscience/variable/parameter.py +++ b/src/easyscience/variable/parameter.py @@ -1025,7 +1025,8 @@ def resolve_pending_dependencies(self) -> None: def _find_parameter_by_serializer_id(self, serializer_id: str) -> Optional['DescriptorNumber']: """Find a parameter by its serializer_id from all parameters in the global map.""" - for obj in self._global_object.map._store.values(): + for key in self._global_object.map.vertices(): + obj = self._global_object.map.get_item_by_key(key) if isinstance(obj, DescriptorNumber) and hasattr(obj, '_DescriptorNumber__serializer_id'): if obj._DescriptorNumber__serializer_id == serializer_id: return obj From 4cf8f2c14a3fc2a2d4f43c15c33248eadddabdf8 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Mon, 15 Dec 2025 14:12:53 +0100 Subject: [PATCH 08/14] move map._store to map.__store to discourage direct access --- src/easyscience/global_object/map.py | 20 ++++++++++---------- tests/unit_tests/global_object/test_map.py | 10 +++++----- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/easyscience/global_object/map.py b/src/easyscience/global_object/map.py index c64bfaff..34459476 100644 --- a/src/easyscience/global_object/map.py +++ b/src/easyscience/global_object/map.py @@ -70,13 +70,13 @@ def is_returned(self) -> bool: class Map: def __init__(self): # A dictionary of object names and their corresponding objects - self._store = weakref.WeakValueDictionary() + self.__store = weakref.WeakValueDictionary() # A dict with object names as keys and a list of their object types as values, with weak references self.__type_dict = {} def vertices(self) -> List[str]: """returns the vertices of a map""" - return list(self._store.keys()) + return list(self.__store.keys()) def edges(self): """returns the edges of a map""" @@ -103,13 +103,13 @@ def _nested_get(self, obj_type: str) -> List[str]: return [key for key, item in self.__type_dict.items() if obj_type in item.type] def get_item_by_key(self, item_id: str) -> object: - if item_id in self._store.keys(): - return self._store[item_id] + if item_id in self.__store.keys(): + return self.__store[item_id] raise ValueError('Item not in map.') def is_known(self, vertex: object) -> bool: # All objects should have a 'unique_name' attribute - return vertex.unique_name in self._store.keys() + return vertex.unique_name in self.__store.keys() def find_type(self, vertex: object) -> List[str]: if self.is_known(vertex): @@ -125,11 +125,11 @@ def change_type(self, obj, new_type: str): def add_vertex(self, obj: object, obj_type: str = None): name = obj.unique_name - if name in self._store.keys(): + if name in self.__store.keys(): raise ValueError(f'Object name {name} already exists in the graph.') - self._store[name] = obj + self.__store[name] = obj self.__type_dict[name] = _EntryList() # Add objects type to the list of types - self.__type_dict[name].finalizer = weakref.finalize(self._store[name], self.prune, name) + self.__type_dict[name].finalizer = weakref.finalize(self.__store[name], self.prune, name) self.__type_dict[name].type = obj_type def add_edge(self, start_obj: object, end_obj: object): @@ -169,7 +169,7 @@ def prune_vertex_from_edge(self, parent_obj, child_obj): def prune(self, key: str): if key in self.__type_dict.keys(): del self.__type_dict[key] - del self._store[key] + del self.__store[key] def find_isolated_vertices(self) -> list: """returns a list of isolated vertices.""" @@ -268,4 +268,4 @@ def _clear(self): self.__type_dict = {} def __repr__(self) -> str: - return f'Map object of {len(self._store)} vertices.' + return f'Map object of {len(self.__store)} vertices.' diff --git a/tests/unit_tests/global_object/test_map.py b/tests/unit_tests/global_object/test_map.py index 05eec1a8..af2d0f6f 100644 --- a/tests/unit_tests/global_object/test_map.py +++ b/tests/unit_tests/global_object/test_map.py @@ -143,29 +143,29 @@ def parameter_object(self): def test_add_vertex(self, clear, base_object, parameter_object): # When Then Expect - assert len(global_object.map._store) == 2 + assert len(global_object.map._Map__store) == 2 assert len(global_object.map._Map__type_dict) == 2 def test_clear(self, clear, base_object): # When - assert len(global_object.map._store) == 1 + assert len(global_object.map._Map__store) == 1 assert len(global_object.map._Map__type_dict) == 1 # Then global_object.map._clear() # Expect - assert len(global_object.map._store) == 0 + assert len(global_object.map._Map__store) == 0 assert global_object.map._Map__type_dict == {} def test_weakref(self, clear): # When test_obj = ObjBase(name="test") - assert len(global_object.map._store) == 1 + assert len(global_object.map._Map__store) == 1 assert len(global_object.map._Map__type_dict) == 1 # Then del test_obj gc.collect() # Expect - assert len(global_object.map._store) == 0 + assert len(global_object.map._Map__store) == 0 assert len(global_object.map._Map__type_dict) == 0 def test_vertices(self, clear, base_object, parameter_object): From 7876d99e95e6be91a940a56ab631d9c21f894094 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Tue, 16 Dec 2025 14:18:40 +0100 Subject: [PATCH 09/14] minor fixes, including deprecation warning --- .../base_classes/model_collection.py | 19 +++++-- .../fitting/calculators/calculator_base.py | 9 +++- .../fitting/calculators/calculator_factory.py | 2 +- .../fitting/calculators/interface_factory.py | 11 ++++ .../calculators/test_calculator_factory.py | 2 +- .../calculators/test_interface_factory.py | 51 +++++++++++++++++++ 6 files changed, 85 insertions(+), 9 deletions(-) diff --git a/src/easyscience/base_classes/model_collection.py b/src/easyscience/base_classes/model_collection.py index d26225ae..71e9fda1 100644 --- a/src/easyscience/base_classes/model_collection.py +++ b/src/easyscience/base_classes/model_collection.py @@ -19,8 +19,12 @@ from .new_base import NewBase if TYPE_CHECKING: + from ..fitting.calculators import CalculatorFactoryBase from ..fitting.calculators import InterfaceFactoryTemplate +# Type alias for interface - supports both legacy and new factory types +InterfaceType = Union['InterfaceFactoryTemplate', 'CalculatorFactoryBase', None] + T = TypeVar('T', bound=NewBase) @@ -35,7 +39,7 @@ def __init__( self, name: str, *args: NewBase, - interface: Optional[InterfaceFactoryTemplate] = None, + interface: InterfaceType = None, unique_name: Optional[str] = None, display_name: Optional[str] = None, ): @@ -51,7 +55,7 @@ def __init__( super().__init__(unique_name=unique_name, display_name=display_name) self._name = name self._data: List[NewBase] = [] - self._interface: Optional[InterfaceFactoryTemplate] = None + self._interface: InterfaceType = None # Add initial items for item in args: @@ -66,9 +70,14 @@ def __init__( self.interface = interface def _add_item(self, item: Any) -> None: - """Add an item to the collection and set up graph edges.""" + """Add an item to the collection and set up graph edges. + + Note: Duplicate items (same object reference) are silently ignored. + """ if not isinstance(item, NewBase): raise TypeError(f'Items must be NewBase objects, got {type(item)}') + if item in self._data: + return # Skip duplicates to avoid multiple graph edges self._data.append(item) self._global_object.map.add_edge(self, item) self._global_object.map.reset_type(item, 'created_internal') @@ -90,12 +99,12 @@ def name(self, new_name: str) -> None: self._name = new_name @property - def interface(self) -> Optional[InterfaceFactoryTemplate]: + def interface(self) -> InterfaceType: """Get the current interface of the collection.""" return self._interface @interface.setter - def interface(self, new_interface: Optional[InterfaceFactoryTemplate]) -> None: + def interface(self, new_interface: InterfaceType) -> None: """Set the interface and propagate to all items.""" self._interface = new_interface for item in self._data: diff --git a/src/easyscience/fitting/calculators/calculator_base.py b/src/easyscience/fitting/calculators/calculator_base.py index 8688a7a5..bd7d12ff 100644 --- a/src/easyscience/fitting/calculators/calculator_base.py +++ b/src/easyscience/fitting/calculators/calculator_base.py @@ -117,6 +117,9 @@ def __init__( self._model = model self._instrumental_parameters = instrumental_parameters self._additional_kwargs = kwargs + # Register this calculator and model in the global object map + if hasattr(model, 'unique_name'): + self._global_object.map.add_edge(self, model) @property def model(self) -> NewBase: @@ -213,12 +216,14 @@ def additional_kwargs(self) -> dict: """ Get additional keyword arguments passed during initialization. + Returns a copy to prevent external modification of internal state. + Returns ------- dict - Dictionary of additional kwargs passed to __init__. + Copy of the dictionary of additional kwargs passed to __init__. """ - return self._additional_kwargs + return dict(self._additional_kwargs) @abstractmethod def calculate(self, x: np.ndarray) -> np.ndarray: diff --git a/src/easyscience/fitting/calculators/calculator_factory.py b/src/easyscience/fitting/calculators/calculator_factory.py index de13c5e4..69397d4d 100644 --- a/src/easyscience/fitting/calculators/calculator_factory.py +++ b/src/easyscience/fitting/calculators/calculator_factory.py @@ -259,7 +259,7 @@ def create( raise ValueError(f"Unknown calculator '{calculator_name}'. Available calculators: {available}") if model is None: - raise TypeError('model cannot be None') + raise TypeError('Model cannot be None') calculator_class = self._calculators[calculator_name] try: diff --git a/src/easyscience/fitting/calculators/interface_factory.py b/src/easyscience/fitting/calculators/interface_factory.py index ca4713fd..35956ad7 100644 --- a/src/easyscience/fitting/calculators/interface_factory.py +++ b/src/easyscience/fitting/calculators/interface_factory.py @@ -3,6 +3,7 @@ # SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause # © 2021-2025 Contributors to the EasyScience project Date: Wed, 17 Dec 2025 15:02:57 +0100 Subject: [PATCH 10/14] ruff format --- src/easyscience/base_classes/model_collection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/easyscience/base_classes/model_collection.py b/src/easyscience/base_classes/model_collection.py index 71e9fda1..62e22817 100644 --- a/src/easyscience/base_classes/model_collection.py +++ b/src/easyscience/base_classes/model_collection.py @@ -71,7 +71,7 @@ def __init__( def _add_item(self, item: Any) -> None: """Add an item to the collection and set up graph edges. - + Note: Duplicate items (same object reference) are silently ignored. """ if not isinstance(item, NewBase): From 0ec3f941f413cc7552f07722f63cb9480d587104 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Wed, 17 Dec 2025 15:36:37 +0100 Subject: [PATCH 11/14] standardize on underscores on develop merge --- src/easyscience/global_object/map.py | 18 +++++++++--------- tests/unit_tests/global_object/test_map.py | 12 ++++++------ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/easyscience/global_object/map.py b/src/easyscience/global_object/map.py index c0173919..b0dde4e4 100644 --- a/src/easyscience/global_object/map.py +++ b/src/easyscience/global_object/map.py @@ -82,7 +82,7 @@ def vertices(self) -> List[str]: """ while True: try: - return list(self._store) + return list(self.__store) except RuntimeError: # Dictionary changed size during iteration, retry continue @@ -112,8 +112,8 @@ def _nested_get(self, obj_type: str) -> List[str]: return [key for key, item in self.__type_dict.items() if obj_type in item.type] def get_item_by_key(self, item_id: str) -> object: - if item_id in self._store: - return self._store[item_id] + if item_id in self.__store: + return self.__store[item_id] raise ValueError('Item not in map.') def is_known(self, vertex: object) -> bool: @@ -121,7 +121,7 @@ def is_known(self, vertex: object) -> bool: All objects should have a 'unique_name' attribute. """ - return vertex.unique_name in self._store + return vertex.unique_name in self.__store def find_type(self, vertex: object) -> List[str]: if self.is_known(vertex): @@ -137,13 +137,13 @@ def change_type(self, obj, new_type: str): def add_vertex(self, obj: object, obj_type: str = None): name = obj.unique_name - if name in self._store: + if name in self.__store: raise ValueError(f'Object name {name} already exists in the graph.') # Clean up stale entry in __type_dict if the weak reference was collected # but the finalizer hasn't run yet if name in self.__type_dict: del self.__type_dict[name] - self._store[name] = obj + self.__store[name] = obj self.__type_dict[name] = _EntryList() # Add objects type to the list of types self.__type_dict[name].finalizer = weakref.finalize(self.__store[name], self.prune, name) self.__type_dict[name].type = obj_type @@ -185,8 +185,8 @@ def prune_vertex_from_edge(self, parent_obj, child_obj): def prune(self, key: str): if key in self.__type_dict: del self.__type_dict[key] - if key in self._store: - del self._store[key] + if key in self.__store: + del self.__store[key] def find_isolated_vertices(self) -> list: """returns a list of isolated vertices.""" @@ -279,7 +279,7 @@ def is_connected(self, vertices_encountered=None, start_vertex=None) -> bool: def _clear(self): """Reset the map to an empty state. Only to be used for testing""" - self._store.clear() + self.__store.clear() self.__type_dict.clear() gc.collect() diff --git a/tests/unit_tests/global_object/test_map.py b/tests/unit_tests/global_object/test_map.py index 44f49a71..00a55e63 100644 --- a/tests/unit_tests/global_object/test_map.py +++ b/tests/unit_tests/global_object/test_map.py @@ -495,7 +495,7 @@ def test_vertices_retry_on_runtime_error(self, clear): # Create a mock _store that raises RuntimeError on first iteration attempt call_count = 0 - original_store = test_map._store + original_store = test_map._Map__store class MockWeakValueDict: def __init__(self): @@ -514,7 +514,7 @@ def __len__(self): return len(self.data) mock_store = MockWeakValueDict() - test_map._store = mock_store + test_map._Map__store = mock_store # When vertices = test_map.vertices() @@ -544,7 +544,7 @@ def test_add_vertex_cleans_stale_type_dict_entry(self, clear): test_map.add_vertex(mock_obj, 'created') # Then - Object should be added successfully - assert stale_name in test_map._store + assert stale_name in test_map._Map__store assert stale_name in test_map._Map__type_dict assert test_map._Map__type_dict[stale_name].type == ['created'] @@ -571,7 +571,7 @@ def test_prune_key_in_both_dicts(self, clear, base_object): """Test that prune removes key from both _store and __type_dict.""" # Given unique_name = base_object.unique_name - assert unique_name in global_object.map._store + assert unique_name in global_object.map._Map__store assert unique_name in global_object.map._Map__type_dict # When @@ -676,14 +676,14 @@ def test_created_internal_property(self, clear): def test_clear_empties_both_dicts(self, clear, base_object, parameter_object): """Test that _clear() properly empties both _store and __type_dict.""" # Given - assert len(global_object.map._store) == 2 + assert len(global_object.map._Map__store) == 2 assert len(global_object.map._Map__type_dict) == 2 # When global_object.map._clear() # Then - assert len(global_object.map._store) == 0 + assert len(global_object.map._Map__store) == 0 assert len(global_object.map._Map__type_dict) == 0 def test_entry_list_delitem(self): From 4446dc1f45fdebb9235883020c2d9d0a709b857f Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Thu, 18 Dec 2025 14:33:55 +0100 Subject: [PATCH 12/14] PR fixes --- .../base_classes/model_collection.py | 57 +++---- .../base_classes/test_model_collection.py | 152 ++++++++++-------- 2 files changed, 106 insertions(+), 103 deletions(-) diff --git a/src/easyscience/base_classes/model_collection.py b/src/easyscience/base_classes/model_collection.py index 62e22817..b382f928 100644 --- a/src/easyscience/base_classes/model_collection.py +++ b/src/easyscience/base_classes/model_collection.py @@ -12,10 +12,10 @@ from typing import List from typing import Optional from typing import TypeVar -from typing import Union from typing import overload from ..variable import Parameter +from .model_base import ModelBase from .new_base import NewBase if TYPE_CHECKING: @@ -23,12 +23,12 @@ from ..fitting.calculators import InterfaceFactoryTemplate # Type alias for interface - supports both legacy and new factory types -InterfaceType = Union['InterfaceFactoryTemplate', 'CalculatorFactoryBase', None] +InterfaceType = 'InterfaceFactoryTemplate | CalculatorFactoryBase | None' T = TypeVar('T', bound=NewBase) -class ModelCollection(NewBase, MutableSequence[T]): +class ModelCollection(ModelBase, MutableSequence[T]): """ A collection class for NewBase/ModelBase objects. This provides list-like functionality while maintaining EasyScience features @@ -37,7 +37,6 @@ class ModelCollection(NewBase, MutableSequence[T]): def __init__( self, - name: str, *args: NewBase, interface: InterfaceType = None, unique_name: Optional[str] = None, @@ -46,14 +45,12 @@ def __init__( """ Initialize the ModelCollection. - :param name: Name of this collection :param args: Initial items to add to the collection :param interface: Optional interface for bindings :param unique_name: Optional unique name for the collection :param display_name: Optional display name for the collection """ super().__init__(unique_name=unique_name, display_name=display_name) - self._name = name self._data: List[NewBase] = [] self._interface: InterfaceType = None @@ -88,16 +85,6 @@ def _remove_item(self, item: NewBase) -> None: """Remove an item from the collection and clean up graph edges.""" self._global_object.map.prune_vertex_from_edge(self, item) - @property - def name(self) -> str: - """Get the name of the collection.""" - return self._name - - @name.setter - def name(self, new_name: str) -> None: - """Set the name of the collection.""" - self._name = new_name - @property def interface(self) -> InterfaceType: """Get the current interface of the collection.""" @@ -105,7 +92,21 @@ def interface(self) -> InterfaceType: @interface.setter def interface(self, new_interface: InterfaceType) -> None: - """Set the interface and propagate to all items.""" + """Set the interface and propagate to all items. + + :param new_interface: The interface to set (must be InterfaceFactoryTemplate, CalculatorFactoryBase, or None) + :raises TypeError: If the interface is not a valid type + """ + # Import here to avoid circular imports + from ..fitting.calculators import CalculatorFactoryBase + from ..fitting.calculators import InterfaceFactoryTemplate + + if new_interface is not None and not isinstance(new_interface, (InterfaceFactoryTemplate, CalculatorFactoryBase)): + raise TypeError( + f'interface must be InterfaceFactoryTemplate, CalculatorFactoryBase, or None, ' + f'got {type(new_interface).__name__}' + ) + self._interface = new_interface for item in self._data: if hasattr(item, 'interface'): @@ -121,7 +122,7 @@ def __getitem__(self, idx: slice) -> 'ModelCollection[T]': ... @overload def __getitem__(self, idx: str) -> T: ... - def __getitem__(self, idx: Union[int, slice, str]) -> Union[T, 'ModelCollection[T]']: + def __getitem__(self, idx: int | slice | str) -> T | 'ModelCollection[T]': """ Get an item by index, slice, or name. @@ -130,7 +131,7 @@ def __getitem__(self, idx: Union[int, slice, str]) -> Union[T, 'ModelCollection[ """ if isinstance(idx, slice): start, stop, step = idx.indices(len(self)) - return self.__class__(self._name, *[self._data[i] for i in range(start, stop, step)]) + return self.__class__(*[self._data[i] for i in range(start, stop, step)]) if isinstance(idx, str): # Search by name for item in self._data: @@ -146,7 +147,7 @@ def __setitem__(self, idx: int, value: T) -> None: ... @overload def __setitem__(self, idx: slice, value: Iterable[T]) -> None: ... - def __setitem__(self, idx: Union[int, slice], value: Union[T, Iterable[T]]) -> None: + def __setitem__(self, idx: int | slice, value: T | Iterable[T]) -> None: """ Set an item at an index. @@ -187,7 +188,7 @@ def __delitem__(self, idx: slice) -> None: ... @overload def __delitem__(self, idx: str) -> None: ... - def __delitem__(self, idx: Union[int, slice, str]) -> None: + def __delitem__(self, idx: int | slice | str) -> None: """ Delete an item by index, slice, or name. @@ -257,7 +258,7 @@ def sort(self, mapping: Callable[[T], Any], reverse: bool = False) -> None: self._data.sort(key=mapping, reverse=reverse) # type: ignore[arg-type] def __repr__(self) -> str: - return f'{self.__class__.__name__} `{self._name}` of length {len(self)}' + return f'{self.__class__.__name__} of length {len(self)}' def __iter__(self) -> Any: return iter(self._data) @@ -281,15 +282,3 @@ def get_all_variables(self) -> List[Any]: if hasattr(item, 'get_all_variables'): variables.extend(item.get_all_variables()) # type: ignore[attr-defined] return variables - - def get_all_parameters(self) -> List[Any]: - """Get all parameters from all items in the collection.""" - return [var for var in self.get_all_variables() if isinstance(var, Parameter)] - - def get_fit_parameters(self) -> List[Any]: - """Get all fittable parameters from all items in the collection.""" - parameters: List[Any] = [] - for item in self._data: - if hasattr(item, 'get_fit_parameters'): - parameters.extend(item.get_fit_parameters()) # type: ignore[attr-defined] - return parameters diff --git a/tests/unit_tests/base_classes/test_model_collection.py b/tests/unit_tests/base_classes/test_model_collection.py index 7ea72c17..b62ddb63 100644 --- a/tests/unit_tests/base_classes/test_model_collection.py +++ b/tests/unit_tests/base_classes/test_model_collection.py @@ -11,6 +11,7 @@ from easyscience.base_classes import ModelCollection from easyscience.base_classes import ModelBase from easyscience.base_classes import NewBase +from easyscience.fitting.calculators import CalculatorFactoryBase class MockModelItem(ModelBase): @@ -67,8 +68,7 @@ def sample_items(): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_init_empty(cls, clear_global): """Test creating an empty collection.""" - coll = cls('test_collection') - assert coll.name == 'test_collection' + coll = cls() assert len(coll) == 0 assert coll.interface is None @@ -76,8 +76,7 @@ def test_ModelCollection_init_empty(cls, clear_global): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_init_with_items(cls, clear_global, sample_items): """Test creating a collection with initial items.""" - coll = cls('test_collection', *sample_items) - assert coll.name == 'test_collection' + coll = cls(*sample_items) assert len(coll) == 3 for i, item in enumerate(coll): assert item.name == sample_items[i].name @@ -86,21 +85,21 @@ def test_ModelCollection_init_with_items(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_init_with_unique_name(cls, clear_global): """Test creating a collection with a custom unique_name.""" - coll = cls('test_collection', unique_name='custom_unique') + coll = cls(unique_name='custom_unique') assert coll.unique_name == 'custom_unique' @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_init_with_display_name(cls, clear_global): """Test creating a collection with a custom display_name.""" - coll = cls('test_collection', display_name='My Display Name') + coll = cls(display_name='My Display Name') assert coll.display_name == 'My Display Name' @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_init_with_list_arg(cls, clear_global, sample_items): """Test creating a collection with a list of items (should flatten).""" - coll = cls('test_collection', sample_items) + coll = cls(sample_items) assert len(coll) == 3 @@ -108,26 +107,7 @@ def test_ModelCollection_init_with_list_arg(cls, clear_global, sample_items): def test_ModelCollection_init_type_error(cls, clear_global): """Test that adding non-NewBase items raises TypeError.""" with pytest.raises(TypeError): - cls('test_collection', 'not_a_newbase_object') - - -# ============================================================================= -# Name Property Tests -# ============================================================================= - -@pytest.mark.parametrize('cls', class_constructors) -def test_ModelCollection_name_getter(cls, clear_global): - """Test getting the collection name.""" - coll = cls('my_collection') - assert coll.name == 'my_collection' - - -@pytest.mark.parametrize('cls', class_constructors) -def test_ModelCollection_name_setter(cls, clear_global): - """Test setting the collection name.""" - coll = cls('old_name') - coll.name = 'new_name' - assert coll.name == 'new_name' + cls('not_a_newbase_object') # ============================================================================= @@ -137,7 +117,7 @@ def test_ModelCollection_name_setter(cls, clear_global): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_interface_default(cls, clear_global): """Test that interface defaults to None.""" - coll = cls('test_collection') + coll = cls() assert coll.interface is None @@ -148,10 +128,16 @@ def test_ModelCollection_interface_propagation(cls, clear_global, sample_items): for item in sample_items: item.interface = None - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) - class MockInterface: - pass + class MockInterface(CalculatorFactoryBase): + """Mock interface for testing.""" + @property + def available_calculators(self): + return [] + + def create(self, calculator_name, *args, **kwargs): + pass mock_interface = MockInterface() coll.interface = mock_interface @@ -161,6 +147,35 @@ class MockInterface: assert item.interface is mock_interface +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_interface_type_error(cls, clear_global): + """Test that setting an invalid interface type raises TypeError.""" + coll = cls() + + with pytest.raises(TypeError, match='interface must be'): + coll.interface = 'not_an_interface' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_interface_type_error_with_object(cls, clear_global): + """Test that setting a plain object as interface raises TypeError.""" + coll = cls() + + class NotAnInterface: + pass + + with pytest.raises(TypeError, match='interface must be'): + coll.interface = NotAnInterface() + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_interface_accepts_none(cls, clear_global, sample_items): + """Test that setting interface to None is allowed.""" + coll = cls(*sample_items) + coll.interface = None + assert coll.interface is None + + # ============================================================================= # __getitem__ Tests # ============================================================================= @@ -168,7 +183,7 @@ class MockInterface: @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_getitem_int(cls, clear_global, sample_items): """Test getting items by integer index.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) assert coll[0].name == 'item1' assert coll[1].name == 'item2' assert coll[2].name == 'item3' @@ -178,7 +193,7 @@ def test_ModelCollection_getitem_int(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_getitem_int_out_of_range(cls, clear_global, sample_items): """Test that out of range index raises IndexError.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) with pytest.raises(IndexError): _ = coll[100] @@ -186,7 +201,7 @@ def test_ModelCollection_getitem_int_out_of_range(cls, clear_global, sample_item @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_getitem_slice(cls, clear_global, sample_items): """Test getting items by slice.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) sliced = coll[0:2] assert isinstance(sliced, cls) assert len(sliced) == 2 @@ -197,7 +212,7 @@ def test_ModelCollection_getitem_slice(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_getitem_str_by_name(cls, clear_global, sample_items): """Test getting items by name string.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) item = coll['item2'] assert item.name == 'item2' @@ -205,7 +220,7 @@ def test_ModelCollection_getitem_str_by_name(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_getitem_str_by_unique_name(cls, clear_global, sample_items): """Test getting items by unique_name string.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) unique_name = sample_items[1].unique_name item = coll[unique_name] assert item.unique_name == unique_name @@ -214,7 +229,7 @@ def test_ModelCollection_getitem_str_by_unique_name(cls, clear_global, sample_it @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_getitem_str_not_found(cls, clear_global, sample_items): """Test that getting non-existent name raises KeyError.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) with pytest.raises(KeyError): _ = coll['nonexistent'] @@ -226,7 +241,7 @@ def test_ModelCollection_getitem_str_not_found(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_setitem_int(cls, clear_global, sample_items): """Test setting items by integer index.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) new_item = MockModelItem(name='new_item', value=99.0) old_item = coll[1] @@ -245,7 +260,7 @@ def test_ModelCollection_setitem_int(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_setitem_type_error(cls, clear_global, sample_items): """Test that setting non-NewBase item raises TypeError.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) with pytest.raises(TypeError): coll[0] = 'not_a_newbase_object' @@ -253,7 +268,7 @@ def test_ModelCollection_setitem_type_error(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_setitem_slice(cls, clear_global, sample_items): """Test setting items by slice.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) new_items = [ MockModelItem(name='new1', value=10.0), MockModelItem(name='new2', value=20.0), @@ -274,7 +289,7 @@ def test_ModelCollection_setitem_slice(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_delitem_int(cls, clear_global, sample_items): """Test deleting items by integer index.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) deleted_item = coll[1] del coll[1] @@ -291,7 +306,7 @@ def test_ModelCollection_delitem_int(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_delitem_slice(cls, clear_global, sample_items): """Test deleting items by slice.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) del coll[0:2] @@ -302,7 +317,7 @@ def test_ModelCollection_delitem_slice(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_delitem_str_by_name(cls, clear_global, sample_items): """Test deleting items by name string.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) del coll['item2'] @@ -313,7 +328,7 @@ def test_ModelCollection_delitem_str_by_name(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_delitem_str_not_found(cls, clear_global, sample_items): """Test that deleting non-existent name raises KeyError.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) with pytest.raises(KeyError): del coll['nonexistent'] @@ -327,7 +342,7 @@ def test_ModelCollection_delitem_str_not_found(cls, clear_global, sample_items): def test_ModelCollection_len(cls, clear_global, count): """Test __len__ returns correct count.""" items = [MockModelItem(name=f'item{i}', value=float(i)) for i in range(count)] - coll = cls('test_collection', *items) + coll = cls(*items) assert len(coll) == count @@ -338,7 +353,7 @@ def test_ModelCollection_len(cls, clear_global, count): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_insert(cls, clear_global, sample_items): """Test inserting items at an index.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) new_item = MockModelItem(name='inserted', value=99.0) coll.insert(1, new_item) @@ -357,7 +372,7 @@ def test_ModelCollection_insert(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_insert_type_error(cls, clear_global, sample_items): """Test that inserting non-NewBase item raises TypeError.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) with pytest.raises(TypeError): coll.insert(0, 'not_a_newbase_object') @@ -369,7 +384,7 @@ def test_ModelCollection_insert_type_error(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_append(cls, clear_global, sample_items): """Test appending items.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) new_item = MockModelItem(name='appended', value=99.0) coll.append(new_item) @@ -389,7 +404,7 @@ def test_ModelCollection_append(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_data_property(cls, clear_global, sample_items): """Test that data property returns tuple of items.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) data = coll.data assert isinstance(data, tuple) assert len(data) == 3 @@ -409,7 +424,7 @@ def test_ModelCollection_sort(cls, clear_global): MockModelItem(name='a', value=1.0), MockModelItem(name='b', value=2.0), ] - coll = cls('test_collection', *items) + coll = cls(*items) coll.sort(lambda x: x.value.value) @@ -426,7 +441,7 @@ def test_ModelCollection_sort_reverse(cls, clear_global): MockModelItem(name='c', value=3.0), MockModelItem(name='b', value=2.0), ] - coll = cls('test_collection', *items) + coll = cls(*items) coll.sort(lambda x: x.value.value, reverse=True) @@ -442,10 +457,9 @@ def test_ModelCollection_sort_reverse(cls, clear_global): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_repr(cls, clear_global, sample_items): """Test string representation.""" - coll = cls('my_collection', *sample_items) + coll = cls(*sample_items) repr_str = repr(coll) assert cls.__name__ in repr_str - assert 'my_collection' in repr_str assert '3' in repr_str @@ -456,7 +470,7 @@ def test_ModelCollection_repr(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_iter(cls, clear_global, sample_items): """Test iteration over collection.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) names = [item.name for item in coll] assert names == ['item1', 'item2', 'item3'] @@ -469,7 +483,7 @@ def test_ModelCollection_iter(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_get_all_variables(cls, clear_global, sample_items): """Test getting all variables from items.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) variables = coll.get_all_variables() # Each MockModelItem has one Parameter (value) @@ -485,7 +499,7 @@ def test_ModelCollection_get_all_variables(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_get_all_parameters(cls, clear_global, sample_items): """Test getting all parameters from items.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) parameters = coll.get_all_parameters() assert len(parameters) == 3 @@ -503,7 +517,7 @@ def test_ModelCollection_get_fit_parameters(cls, clear_global, sample_items): # Fix one parameter so we can test filtering sample_items[0].value.fixed = True - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) fit_params = coll.get_fit_parameters() # All 3 parameters should be returned (get_fit_parameters on items) @@ -518,7 +532,7 @@ def test_ModelCollection_get_fit_parameters(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_graph_edges(cls, clear_global, sample_items): """Test that graph edges are correctly maintained.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) edges = global_object.map.get_edges(coll) assert len(edges) == 3 @@ -530,7 +544,7 @@ def test_ModelCollection_graph_edges(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_graph_edges_after_append(cls, clear_global, sample_items): """Test graph edges are updated after append.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) new_item = MockModelItem(name='new', value=99.0) coll.append(new_item) @@ -543,7 +557,7 @@ def test_ModelCollection_graph_edges_after_append(cls, clear_global, sample_item @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_graph_edges_after_delete(cls, clear_global, sample_items): """Test graph edges are updated after delete.""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) deleted_item = sample_items[1] del coll[1] @@ -560,7 +574,7 @@ def test_ModelCollection_graph_edges_after_delete(cls, clear_global, sample_item @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_extend(cls, clear_global, sample_items): """Test extend method (inherited from MutableSequence).""" - coll = cls('test_collection', sample_items[0]) + coll = cls(sample_items[0]) coll.extend([sample_items[1], sample_items[2]]) assert len(coll) == 3 @@ -571,7 +585,7 @@ def test_ModelCollection_extend(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_pop(cls, clear_global, sample_items): """Test pop method (inherited from MutableSequence).""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) popped = coll.pop() assert popped.name == 'item3' @@ -581,7 +595,7 @@ def test_ModelCollection_pop(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_pop_index(cls, clear_global, sample_items): """Test pop method with index (inherited from MutableSequence).""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) popped = coll.pop(0) assert popped.name == 'item1' @@ -592,7 +606,7 @@ def test_ModelCollection_pop_index(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_remove(cls, clear_global, sample_items): """Test remove method (inherited from MutableSequence).""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) item_to_remove = sample_items[1] coll.remove(item_to_remove) @@ -604,7 +618,7 @@ def test_ModelCollection_remove(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_clear(cls, clear_global, sample_items): """Test clear method (inherited from MutableSequence).""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) coll.clear() @@ -614,7 +628,7 @@ def test_ModelCollection_clear(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_reverse(cls, clear_global, sample_items): """Test reverse method (inherited from MutableSequence).""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) coll.reverse() @@ -626,7 +640,7 @@ def test_ModelCollection_reverse(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_count(cls, clear_global, sample_items): """Test count method (inherited from MutableSequence).""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) count = coll.count(sample_items[0]) assert count == 1 @@ -635,7 +649,7 @@ def test_ModelCollection_count(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_index(cls, clear_global, sample_items): """Test index method (inherited from MutableSequence).""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) idx = coll.index(sample_items[1]) assert idx == 1 @@ -648,7 +662,7 @@ def test_ModelCollection_index(cls, clear_global, sample_items): @pytest.mark.parametrize('cls', class_constructors) def test_ModelCollection_contains(cls, clear_global, sample_items): """Test __contains__ (in operator).""" - coll = cls('test_collection', *sample_items) + coll = cls(*sample_items) assert sample_items[0] in coll assert sample_items[1] in coll From c5dfc687a6f726c90cc8e896ceb2798dc0dfe569 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Thu, 18 Dec 2025 14:35:59 +0100 Subject: [PATCH 13/14] ruff --- src/easyscience/base_classes/model_collection.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/easyscience/base_classes/model_collection.py b/src/easyscience/base_classes/model_collection.py index b382f928..04561059 100644 --- a/src/easyscience/base_classes/model_collection.py +++ b/src/easyscience/base_classes/model_collection.py @@ -14,13 +14,11 @@ from typing import TypeVar from typing import overload -from ..variable import Parameter from .model_base import ModelBase from .new_base import NewBase if TYPE_CHECKING: - from ..fitting.calculators import CalculatorFactoryBase - from ..fitting.calculators import InterfaceFactoryTemplate + pass # Type alias for interface - supports both legacy and new factory types InterfaceType = 'InterfaceFactoryTemplate | CalculatorFactoryBase | None' From d0833d060e65644fbf647bf6dcbcf13647b978f8 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Thu, 18 Dec 2025 14:46:37 +0100 Subject: [PATCH 14/14] ruff format... --- src/easyscience/base_classes/model_collection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/easyscience/base_classes/model_collection.py b/src/easyscience/base_classes/model_collection.py index 04561059..6f26085f 100644 --- a/src/easyscience/base_classes/model_collection.py +++ b/src/easyscience/base_classes/model_collection.py @@ -91,20 +91,20 @@ def interface(self) -> InterfaceType: @interface.setter def interface(self, new_interface: InterfaceType) -> None: """Set the interface and propagate to all items. - + :param new_interface: The interface to set (must be InterfaceFactoryTemplate, CalculatorFactoryBase, or None) :raises TypeError: If the interface is not a valid type """ # Import here to avoid circular imports from ..fitting.calculators import CalculatorFactoryBase from ..fitting.calculators import InterfaceFactoryTemplate - + if new_interface is not None and not isinstance(new_interface, (InterfaceFactoryTemplate, CalculatorFactoryBase)): raise TypeError( f'interface must be InterfaceFactoryTemplate, CalculatorFactoryBase, or None, ' f'got {type(new_interface).__name__}' ) - + self._interface = new_interface for item in self._data: if hasattr(item, 'interface'):