diff --git a/ax/api/utils/generation_strategy_dispatch.py b/ax/api/utils/generation_strategy_dispatch.py index f3f60d2fa35..0acc03f46fc 100644 --- a/ax/api/utils/generation_strategy_dispatch.py +++ b/ax/api/utils/generation_strategy_dispatch.py @@ -12,6 +12,7 @@ import torch from ax.adapter.registry import Generators from ax.api.utils.structs import GenerationStrategyDispatchStruct +from ax.core.experiment_status import ExperimentStatus from ax.core.trial_status import TrialStatus from ax.exceptions.core import UnsupportedError, UserInputError from ax.generation_strategy.center_generation_node import CenterGenerationNode @@ -95,6 +96,7 @@ def _get_sobol_node( ], transition_criteria=transition_criteria, should_deduplicate=True, + suggested_experiment_status=ExperimentStatus.INITIALIZATION, ) @@ -175,6 +177,7 @@ def _get_mbm_node( ) ], should_deduplicate=True, + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, ), mbm_name @@ -225,6 +228,7 @@ def choose_generation_strategy( generator_kwargs={"seed": struct.initialization_random_seed}, ) ], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, ) ] gs_name = "QuasiRandomSearch" diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 9cdae91e135..3662eb51cd2 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -29,6 +29,7 @@ from ax.core.base_trial import BaseTrial from ax.core.batch_trial import BatchTrial from ax.core.data import combine_data_rows_favoring_recent, Data +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric, MetricFetchE, MetricFetchResult from ax.core.objective import MultiObjective @@ -146,6 +147,7 @@ def __init__( self._optimization_config: OptimizationConfig | None = None self._tracking_metrics: dict[str, Metric] = {} self._time_created: datetime = datetime.now() + self._status: ExperimentStatus | None = None self._trials: dict[int, BaseTrial] = {} self._properties: dict[str, Any] = properties or {} @@ -231,6 +233,27 @@ def experiment_type(self, experiment_type: str | None) -> None: """Set the type of the experiment.""" self._experiment_type = experiment_type + @property + def status(self) -> ExperimentStatus | None: + """The current status of the experiment. + + Status tracks the high-level lifecycle phase of the experiment: + DRAFT, INITIALIZATION, OPTIMIZATION, COMPLETED. + + For new experiments, status defaults to DRAFT. For legacy experiments + that were created before the status field was added, status may be None. + """ + return self._status + + @status.setter + def status(self, status: ExperimentStatus | None) -> None: + """Set the status of the experiment. + + Args: + status: The new status for the experiment. + """ + self._status = status + @property def search_space(self) -> SearchSpace: """The search space for this experiment. diff --git a/ax/core/experiment_status.py b/ax/core/experiment_status.py new file mode 100644 index 00000000000..25a11c81bcd --- /dev/null +++ b/ax/core/experiment_status.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from __future__ import annotations + +from enum import Enum + + +class ExperimentStatus(int, Enum): + """Enum of experiment status. + + General lifecycle of an experiment is::: + + DRAFT --> INITIALIZATION --> OPTIMIZATION --> COMPLETED + + Experiment is marked as ``DRAFT`` immediately upon its creation when + the experiment is still being configured (search space, optimization config, etc.). + + Once the experiment is fully configured and begins initial exploration, + it transitions to ``INITIALIZATION``. This is typically when the first trials + are being generated to explore the search space. + + After initial exploration completes (typically after some data has been collected), + the experiment transitions to ``OPTIMIZATION``, where Bayesian optimization or + other adaptive methods are used to find optimal configurations. + + ``COMPLETED`` indicates the experiment has successfully finished its objectives. + + Note: This status tracks the high-level experiment lifecycle and is independent + of individual trial statuses. An experiment in OPTIMIZATION status may have + trials in various states (RUNNING, COMPLETED, FAILED, etc.). + """ + + DRAFT = 0 + INITIALIZATION = 1 + OPTIMIZATION = 2 + COMPLETED = 4 + + @property + def is_active(self) -> bool: + """True if experiment is actively running trials.""" + return ( + self == ExperimentStatus.INITIALIZATION + or self == ExperimentStatus.OPTIMIZATION + ) + + @property + def is_draft(self) -> bool: + """True if experiment is in draft phase.""" + return self == ExperimentStatus.DRAFT + + @property + def is_initialization(self) -> bool: + """True if experiment is in initialization phase.""" + return self == ExperimentStatus.INITIALIZATION + + @property + def is_optimization(self) -> bool: + """True if experiment is in optimization phase.""" + return self == ExperimentStatus.OPTIMIZATION + + @property + def is_completed(self) -> bool: + """True if experiment has successfully completed.""" + return self == ExperimentStatus.COMPLETED + + def __format__(self, fmt: str) -> str: + """Define `__format__` to avoid pulling the `__format__` from the `int` + mixin (since its better for statuses to show up as `DRAFT` than as + just an int that is difficult to interpret). + + E.g. experiment representation with the overridden method is: + "Experiment(name='test', status=ExperimentStatus.DRAFT)". + + Docs on enum formatting: https://docs.python.org/3/library/enum.html#others. + """ + return f"{self!s}" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index 9903fbeb578..5fd25d61694 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -17,6 +17,7 @@ from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.data import Data, sort_by_trial_index_and_arm_name from ax.core.evaluations_to_data import raw_evaluations_to_data +from ax.core.experiment_status import ExperimentStatus from ax.core.map_metric import MapMetric from ax.core.metric import Metric from ax.core.objective import MultiObjective, Objective @@ -1853,6 +1854,15 @@ def test_to_df_with_relativize(self) -> None: "relativized value", ) + def test_experiment_status_default(self) -> None: + """Test that new experiments have None status for backward compatibility.""" + self.assertIsNone(self.experiment.status) + + def test_experiment_status_property(self) -> None: + """Test the experiment status property getter and setter.""" + self.experiment.status = ExperimentStatus.DRAFT + self.assertEqual(self.experiment.status, ExperimentStatus.DRAFT) + class ExperimentWithMapDataTest(TestCase): def setUp(self) -> None: diff --git a/ax/core/tests/test_experiment_status.py b/ax/core/tests/test_experiment_status.py new file mode 100644 index 00000000000..06ae415d900 --- /dev/null +++ b/ax/core/tests/test_experiment_status.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from ax.core.experiment_status import ExperimentStatus +from ax.utils.common.testutils import TestCase + + +class TestExperimentStatus(TestCase): + """Tests for the ExperimentStatus enum.""" + + def test_status_values(self) -> None: + """Test that status enum values are correctly defined.""" + self.assertEqual(ExperimentStatus.DRAFT.value, 0) + self.assertEqual(ExperimentStatus.INITIALIZATION.value, 1) + self.assertEqual(ExperimentStatus.OPTIMIZATION.value, 2) + self.assertEqual(ExperimentStatus.COMPLETED.value, 4) + + def test_is_active(self) -> None: + """Test the is_active property.""" + # Active statuses + self.assertTrue(ExperimentStatus.INITIALIZATION.is_active) + self.assertTrue(ExperimentStatus.OPTIMIZATION.is_active) + + # Inactive statuses + self.assertFalse(ExperimentStatus.DRAFT.is_active) + self.assertFalse(ExperimentStatus.COMPLETED.is_active) + + def test_individual_status_checks(self) -> None: + """Test individual status check properties.""" + self.assertTrue(ExperimentStatus.DRAFT.is_draft) + self.assertFalse(ExperimentStatus.INITIALIZATION.is_draft) + + self.assertTrue(ExperimentStatus.INITIALIZATION.is_initialization) + self.assertFalse(ExperimentStatus.OPTIMIZATION.is_initialization) + + self.assertTrue(ExperimentStatus.OPTIMIZATION.is_optimization) + self.assertFalse(ExperimentStatus.COMPLETED.is_optimization) + + self.assertTrue(ExperimentStatus.COMPLETED.is_completed) + self.assertFalse(ExperimentStatus.DRAFT.is_completed) + + def test_format_and_repr(self) -> None: + """Test __format__ and __repr__ methods.""" + status = ExperimentStatus.DRAFT + self.assertEqual(f"{status}", "ExperimentStatus.DRAFT") + self.assertEqual(repr(status), "ExperimentStatus.DRAFT") + + status = ExperimentStatus.OPTIMIZATION + self.assertEqual(f"{status}", "ExperimentStatus.OPTIMIZATION") + self.assertEqual(repr(status), "ExperimentStatus.OPTIMIZATION") diff --git a/ax/generation_strategy/center_generation_node.py b/ax/generation_strategy/center_generation_node.py index d8f1e8ee90e..201cf4ed646 100644 --- a/ax/generation_strategy/center_generation_node.py +++ b/ax/generation_strategy/center_generation_node.py @@ -13,6 +13,7 @@ from ax.core.arm import Arm from ax.core.data import Data from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.core.parameter import DerivedParameter @@ -29,7 +30,12 @@ class CenterGenerationNode(ExternalGenerationNode): next_node_name: str - def __init__(self, next_node_name: str) -> None: + def __init__( + self, + next_node_name: str, + suggested_experiment_status: ExperimentStatus + | None = ExperimentStatus.INITIALIZATION, + ) -> None: """A generation node that samples the center of the search space. This generation node is only used to generate the first point of the experiment. After one point is generated, it will transition to `next_node_name`. @@ -37,9 +43,16 @@ def __init__(self, next_node_name: str) -> None: If the generated point is a duplicate of an arm already attached to the experiment, this will fallback to Sobol through the use of ``GenerationNode`` deduplication logic. + + Args: + next_node_name: The name of the node to transition to after generating + the center point. + suggested_experiment_status: Optional suggested experiment status for this + node. """ super().__init__( name="CenterOfSearchSpace", + suggested_experiment_status=suggested_experiment_status, transition_criteria=[ AutoTransitionAfterGen( transition_to=next_node_name, diff --git a/ax/generation_strategy/external_generation_node.py b/ax/generation_strategy/external_generation_node.py index b5192a97b9d..6e153848895 100644 --- a/ax/generation_strategy/external_generation_node.py +++ b/ax/generation_strategy/external_generation_node.py @@ -14,6 +14,7 @@ from ax.core.arm import Arm from ax.core.data import Data from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.core.types import TParameterization @@ -48,6 +49,7 @@ class ExternalGenerationNode(GenerationNode, ABC): def __init__( self, name: str, + suggested_experiment_status: ExperimentStatus | None = None, should_deduplicate: bool = True, transition_criteria: Sequence[TransitionCriterion] | None = None, ) -> None: @@ -59,6 +61,8 @@ def __init__( Args: name: Name of the generation node. + suggested_experiment_status: Optional suggested experiment status for this + node. Defaults to None if not specified. should_deduplicate: Whether to deduplicate the generated points against the existing trials on the experiment. If True, the duplicate points will be discarded and re-generated up to 5 times, after which a @@ -73,6 +77,7 @@ def __init__( super().__init__( name=name, generator_specs=[], + suggested_experiment_status=suggested_experiment_status, best_model_selector=None, should_deduplicate=should_deduplicate, transition_criteria=transition_criteria, diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index a3efebc8cd7..1af6c060a02 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -15,6 +15,7 @@ from ax.core.data import Data from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.core.trial_status import TrialStatus @@ -113,6 +114,10 @@ class GenerationNode(SerializationMixin, SortableBase): store the most recent previous ``GenerationNode`` name. should_skip: Whether to skip this node during generation time. Defaults to False, and can only currently be set to True via ``NodeInputConstructors`` + suggested_experiment_status: Optional ``ExperimentStatus`` that indicates + what the experiment's status should be once the experiment adds trials + using ``GeneratorRun``-s produced from this node. This is advisory only + and does not automatically update the experiment's status. fallback_specs: Optional dict mapping expected exception types to `ModelSpec` fallbacks used when gen fails. @@ -135,6 +140,7 @@ class GenerationNode(SerializationMixin, SortableBase): _previous_node_name: str | None = None _trial_type: str | None = None _should_skip: bool = False + _suggested_experiment_status: ExperimentStatus | None = None fallback_specs: dict[type[Exception], GeneratorSpec] # [TODO] Handle experiment passing more eloquently by enforcing experiment @@ -156,6 +162,7 @@ def __init__( previous_node_name: str | None = None, trial_type: str | None = None, should_skip: bool = False, + suggested_experiment_status: ExperimentStatus | None = None, fallback_specs: dict[type[Exception], GeneratorSpec] | None = None, ) -> None: self._name = name @@ -188,6 +195,7 @@ def __init__( self._previous_node_name = previous_node_name self._trial_type = trial_type self._should_skip = should_skip + self._suggested_experiment_status = suggested_experiment_status self.fallback_specs = ( fallback_specs if fallback_specs is not None else DEFAULT_FALLBACK ) @@ -366,6 +374,10 @@ def __repr__(self) -> str: str_rep += ( f", transition_criteria={str(self._brief_transition_criteria_repr())}" ) + if self._suggested_experiment_status is not None: + str_rep += ( + f", suggested_experiment_status={self._suggested_experiment_status!r}" + ) return f"{str_rep})" def _fit( @@ -999,6 +1011,7 @@ class GenerationStep: whether to transition to the next step. If False, `num_trials` and `min_trials_observed` will only count trials generatd by this step. If True, they will count all trials in the experiment (of corresponding statuses). + suggested_experiment_status: The suggested experiment status for this step. Note for developers: by "generator" here we really mean an ``Adapter`` object, which contains a ``Generator`` under the hood. We call it "generator" here to simplify and @@ -1019,6 +1032,7 @@ def __new__( use_all_trials_in_exp: bool = False, use_update: bool = False, # DEPRECATED. index: int = -1, # Index of this step, set internally. + suggested_experiment_status: ExperimentStatus | None = None, # Deprecated arguments for backwards compatibility. model_kwargs: dict[str, Any] | None = None, model_gen_kwargs: dict[str, Any] | None = None, @@ -1135,6 +1149,7 @@ def __new__( step_index=index, generator_name=resolved_generator_name ), generator_specs=[generator_spec], + suggested_experiment_status=suggested_experiment_status, should_deduplicate=should_deduplicate, transition_criteria=transition_criteria, ) diff --git a/ax/generation_strategy/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py index 3cae54f0b27..deaec393afc 100644 --- a/ax/generation_strategy/tests/test_generation_node.py +++ b/ax/generation_strategy/tests/test_generation_node.py @@ -11,6 +11,7 @@ import torch from ax.adapter.factory import get_sobol from ax.adapter.registry import Generators +from ax.core.experiment_status import ExperimentStatus from ax.core.observation import ObservationFeatures from ax.core.trial_status import TrialStatus from ax.exceptions.core import UserInputError @@ -53,13 +54,16 @@ def setUp(self) -> None: generator_gen_kwargs={}, ) self.sobol_generation_node = GenerationNode( - name="test", generator_specs=[self.sobol_generator_spec] + name="test", + generator_specs=[self.sobol_generator_spec], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, ) self.branin_experiment = get_branin_experiment(with_completed_trial=True) self.branin_data = self.branin_experiment.lookup_data() self.node_short = GenerationNode( name="test", generator_specs=[self.sobol_generator_spec], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, trial_type=Keys.SHORT_RUN, ) @@ -97,6 +101,30 @@ def test_init(self) -> None: self.assertEqual(node.generator_specs, mbm_specs) self.assertIs(node.best_model_selector, model_selector) + def test_suggested_experiment_status(self) -> None: + """Test that suggested_experiment_status is properly set and accessible.""" + with self.subTest("initialization set"): + self.assertEqual( + self.sobol_generation_node._suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + + with self.subTest("default None when not provided"): + node_without_state = GenerationNode( + name="test", + generator_specs=[self.sobol_generator_spec], + ) + self.assertIsNone(node_without_state._suggested_experiment_status) + + with self.subTest("__repr__ includes status when set"): + repr_str = repr(self.sobol_generation_node) + self.assertIn("suggested_experiment_status", repr_str) + self.assertIn("INITIALIZATION", repr_str) + + with self.subTest("__repr__ excludes status when None"): + repr_str_without = repr(node_without_state) + self.assertNotIn("suggested_experiment_status", repr_str_without) + def test_input_constructor_none(self) -> None: self.assertEqual(self.sobol_generation_node._input_constructors, {}) self.assertEqual(self.sobol_generation_node.input_constructors, {}) @@ -320,6 +348,7 @@ def test_node_string_representation(self) -> None: generator_specs=[ self.mbm_generator_spec, ], + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, transition_criteria=[ MinTrials( threshold=5, @@ -335,7 +364,8 @@ def test_node_string_representation(self) -> None: "GenerationNode(name='test', " "generator_specs=[GeneratorSpec(generator_enum=BoTorch, " "generator_key_override=None)], " - "transition_criteria=[MinTrials(transition_to='next_node')])", + "transition_criteria=[MinTrials(transition_to='next_node')], " + "suggested_experiment_status=ExperimentStatus.OPTIMIZATION)", ) def test_single_fixed_features(self) -> None: diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 0beccd47c09..7c617316f15 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -889,6 +889,15 @@ def generation_node_from_json( if "trial_type" in generation_node_json.keys() else None ), + suggested_experiment_status=( + object_from_json( + object_json=generation_node_json.pop("suggested_experiment_status"), + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + if "suggested_experiment_status" in generation_node_json.keys() + else None # Default for old records without the field + ), ) diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index bfd6157f129..0b2ef485563 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -411,6 +411,7 @@ def generation_node_to_dict(generation_node: GenerationNode) -> dict[str, Any]: "generator_spec_to_gen_from": generation_node._generator_spec_to_gen_from, "previous_node_name": generation_node._previous_node_name, "trial_type": generation_node._trial_type, + "suggested_experiment_status": generation_node._suggested_experiment_status, # need to manually encode input constructors because the key is an enum. # Our encoding and decoding logic in object_to_json and object_from_json # doesn't recursively encode/decode the keys of dictionaries. diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 842a90968be..987c008d022 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -29,6 +29,7 @@ from ax.core.batch_trial import AbandonedArm, BatchTrial from ax.core.data import Data from ax.core.evaluations_to_data import DataType +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.map_metric import MapMetric from ax.core.metric import Metric @@ -315,6 +316,7 @@ "DerivedParameter": DerivedParameter, "DomainType": DomainType, "Experiment": Experiment, + "ExperimentStatus": ExperimentStatus, "FactorialMetric": FactorialMetric, "FilterFeatures": FilterFeatures, "FixedParameter": fixed_parameter_from_json, diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 3ba9ef21be3..dd47b85dea3 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -408,6 +408,7 @@ def experiment_from_sqa( _cast_arm_parameters(sq, experiment.search_space) experiment._register_arm(sq) experiment._time_created = experiment_sqa.time_created + experiment._status = experiment_sqa.status experiment._experiment_type = self.get_enum_name( value=experiment_sqa.experiment_type, enum=self.config.experiment_type_enum ) diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 16b7a58b043..36756867491 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -246,6 +246,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: status_quo_name=status_quo_name, status_quo_parameters=status_quo_parameters, time_created=experiment.time_created, + status=experiment.status, experiment_type=experiment_type, metrics=optimization_metrics + tracking_metrics, parameters=parameters, diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index 7b8fe5dc657..dff59cb9e7a 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -13,6 +13,7 @@ from typing import Any from ax.core.evaluations_to_data import DataType +from ax.core.experiment_status import ExperimentStatus from ax.core.parameter import ParameterType from ax.core.trial_status import TrialStatus from ax.core.types import ( @@ -375,6 +376,9 @@ class SQAExperiment(Base): JSONEncodedTextDict ) time_created: Column[datetime] = Column(IntTimestamp, nullable=False) + status: Column[ExperimentStatus | None] = Column( + IntEnum(ExperimentStatus), nullable=True + ) default_trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) default_data_type: Column[DataType] = Column(IntEnum(DataType), nullable=True) # pyre-fixme[8]: Incompatible attribute type [8]: Attribute diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index d0a9df77ce4..9d72ea425dc 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -32,6 +32,7 @@ TransferLearningMetadata, ) from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric from ax.core.objective import MultiObjective, Objective, ScalarizedObjective @@ -614,6 +615,27 @@ def test_saving_an_experiment_with_type_errors_with_missing_enum_value( config=SQAConfig(experiment_type_enum=MockExperimentTypeEnum), ) + def test_experiment_status_save_load(self) -> None: + """Test that experiment status is correctly saved and loaded.""" + # Test None status (backward compatibility) + with self.subTest(status=None): + exp = get_experiment() + exp._name = "test_exp_status_none" + exp.status = None + save_experiment(exp) + loaded_exp = load_experiment(exp.name) + self.assertEqual(loaded_exp.status, None) + + # Test all ExperimentStatus enum values + for status in ExperimentStatus: + with self.subTest(status=status): + exp = get_experiment() + exp._name = f"test_exp_status_{status.name.lower()}" + exp.status = status + save_experiment(exp) + loaded_exp = load_experiment(exp.name) + self.assertEqual(loaded_exp.status, status) + def test_load_experiment_trials_in_batches(self) -> None: for _ in range(4): self.experiment.new_trial()