Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ax.core.base_trial import BaseTrial
from ax.core.batch_trial import BatchTrial
from ax.core.data import combine_data_rows_favoring_recent, Data
from ax.core.experiment_status import ExperimentStatus
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.core.objective import MultiObjective
Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(
self._optimization_config: OptimizationConfig | None = None
self._tracking_metrics: dict[str, Metric] = {}
self._time_created: datetime = datetime.now()
self._status: ExperimentStatus | None = None
self._trials: dict[int, BaseTrial] = {}
self._properties: dict[str, Any] = properties or {}

Expand Down Expand Up @@ -231,6 +233,27 @@ def experiment_type(self, experiment_type: str | None) -> None:
"""Set the type of the experiment."""
self._experiment_type = experiment_type

@property
def status(self) -> ExperimentStatus | None:
"""The current status of the experiment.

Status tracks the high-level lifecycle phase of the experiment:
DRAFT, INITIALIZATION, OPTIMIZATION, COMPLETED.

For new experiments, status defaults to DRAFT. For legacy experiments
that were created before the status field was added, status may be None.
"""
return self._status

@status.setter
def status(self, status: ExperimentStatus | None) -> None:
"""Set the status of the experiment.

Args:
status: The new status for the experiment.
"""
self._status = status

@property
def search_space(self) -> SearchSpace:
"""The search space for this experiment.
Expand Down
85 changes: 85 additions & 0 deletions ax/core/experiment_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from enum import Enum


class ExperimentStatus(int, Enum):
"""Enum of experiment status.

General lifecycle of an experiment is:::

DRAFT --> INITIALIZATION --> OPTIMIZATION --> COMPLETED

Experiment is marked as ``DRAFT`` immediately upon its creation when
the experiment is still being configured (search space, optimization config, etc.).

Once the experiment is fully configured and begins initial exploration,
it transitions to ``INITIALIZATION``. This is typically when the first trials
are being generated to explore the search space.

After initial exploration completes (typically after some data has been collected),
the experiment transitions to ``OPTIMIZATION``, where Bayesian optimization or
other adaptive methods are used to find optimal configurations.

``COMPLETED`` indicates the experiment has successfully finished its objectives.

Note: This status tracks the high-level experiment lifecycle and is independent
of individual trial statuses. An experiment in OPTIMIZATION status may have
trials in various states (RUNNING, COMPLETED, FAILED, etc.).
"""

DRAFT = 0
INITIALIZATION = 1
OPTIMIZATION = 2
COMPLETED = 4

@property
def is_active(self) -> bool:
"""True if experiment is actively running trials."""
return (
self == ExperimentStatus.INITIALIZATION
or self == ExperimentStatus.OPTIMIZATION
)

@property
def is_draft(self) -> bool:
"""True if experiment is in draft phase."""
return self == ExperimentStatus.DRAFT

@property
def is_initialization(self) -> bool:
"""True if experiment is in initialization phase."""
return self == ExperimentStatus.INITIALIZATION

@property
def is_optimization(self) -> bool:
"""True if experiment is in optimization phase."""
return self == ExperimentStatus.OPTIMIZATION

@property
def is_completed(self) -> bool:
"""True if experiment has successfully completed."""
return self == ExperimentStatus.COMPLETED

def __format__(self, fmt: str) -> str:
"""Define `__format__` to avoid pulling the `__format__` from the `int`
mixin (since its better for statuses to show up as `DRAFT` than as
just an int that is difficult to interpret).

E.g. experiment representation with the overridden method is:
"Experiment(name='test', status=ExperimentStatus.DRAFT)".

Docs on enum formatting: https://docs.python.org/3/library/enum.html#others.
"""
return f"{self!s}"

def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
10 changes: 10 additions & 0 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.data import Data, sort_by_trial_index_and_arm_name
from ax.core.evaluations_to_data import raw_evaluations_to_data
from ax.core.experiment_status import ExperimentStatus
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric
from ax.core.objective import MultiObjective, Objective
Expand Down Expand Up @@ -1853,6 +1854,15 @@ def test_to_df_with_relativize(self) -> None:
"relativized value",
)

def test_experiment_status_default(self) -> None:
"""Test that new experiments have None status for backward compatibility."""
self.assertIsNone(self.experiment.status)

def test_experiment_status_property(self) -> None:
"""Test the experiment status property getter and setter."""
self.experiment.status = ExperimentStatus.DRAFT
self.assertEqual(self.experiment.status, ExperimentStatus.DRAFT)


class ExperimentWithMapDataTest(TestCase):
def setUp(self) -> None:
Expand Down
55 changes: 55 additions & 0 deletions ax/core/tests/test_experiment_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from ax.core.experiment_status import ExperimentStatus
from ax.utils.common.testutils import TestCase


class TestExperimentStatus(TestCase):
"""Tests for the ExperimentStatus enum."""

def test_status_values(self) -> None:
"""Test that status enum values are correctly defined."""
self.assertEqual(ExperimentStatus.DRAFT.value, 0)
self.assertEqual(ExperimentStatus.INITIALIZATION.value, 1)
self.assertEqual(ExperimentStatus.OPTIMIZATION.value, 2)
self.assertEqual(ExperimentStatus.COMPLETED.value, 4)

def test_is_active(self) -> None:
"""Test the is_active property."""
# Active statuses
self.assertTrue(ExperimentStatus.INITIALIZATION.is_active)
self.assertTrue(ExperimentStatus.OPTIMIZATION.is_active)

# Inactive statuses
self.assertFalse(ExperimentStatus.DRAFT.is_active)
self.assertFalse(ExperimentStatus.COMPLETED.is_active)

def test_individual_status_checks(self) -> None:
"""Test individual status check properties."""
self.assertTrue(ExperimentStatus.DRAFT.is_draft)
self.assertFalse(ExperimentStatus.INITIALIZATION.is_draft)

self.assertTrue(ExperimentStatus.INITIALIZATION.is_initialization)
self.assertFalse(ExperimentStatus.OPTIMIZATION.is_initialization)

self.assertTrue(ExperimentStatus.OPTIMIZATION.is_optimization)
self.assertFalse(ExperimentStatus.COMPLETED.is_optimization)

self.assertTrue(ExperimentStatus.COMPLETED.is_completed)
self.assertFalse(ExperimentStatus.DRAFT.is_completed)

def test_format_and_repr(self) -> None:
"""Test __format__ and __repr__ methods."""
status = ExperimentStatus.DRAFT
self.assertEqual(f"{status}", "ExperimentStatus.DRAFT")
self.assertEqual(repr(status), "ExperimentStatus.DRAFT")

status = ExperimentStatus.OPTIMIZATION
self.assertEqual(f"{status}", "ExperimentStatus.OPTIMIZATION")
self.assertEqual(repr(status), "ExperimentStatus.OPTIMIZATION")
15 changes: 14 additions & 1 deletion ax/generation_strategy/center_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ax.core.arm import Arm
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.experiment_status import ExperimentStatus
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.parameter import DerivedParameter
Expand All @@ -29,17 +30,29 @@
class CenterGenerationNode(ExternalGenerationNode):
next_node_name: str

def __init__(self, next_node_name: str) -> None:
def __init__(
self,
next_node_name: str,
suggested_experiment_status: ExperimentStatus
| None = ExperimentStatus.INITIALIZATION,
) -> None:
"""A generation node that samples the center of the search space.
This generation node is only used to generate the first point of the experiment.
After one point is generated, it will transition to `next_node_name`.

If the generated point is a duplicate of an arm already attached to the
experiment, this will fallback to Sobol through the use of ``GenerationNode``
deduplication logic.

Args:
next_node_name: The name of the node to transition to after generating
the center point.
suggested_experiment_status: Optional suggested experiment status for this
node.
"""
super().__init__(
name="CenterOfSearchSpace",
suggested_experiment_status=suggested_experiment_status,
transition_criteria=[
AutoTransitionAfterGen(
transition_to=next_node_name,
Expand Down
5 changes: 5 additions & 0 deletions ax/generation_strategy/external_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ax.core.arm import Arm
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.experiment_status import ExperimentStatus
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.types import TParameterization
Expand Down Expand Up @@ -48,6 +49,7 @@ class ExternalGenerationNode(GenerationNode, ABC):
def __init__(
self,
name: str,
suggested_experiment_status: ExperimentStatus | None = None,
should_deduplicate: bool = True,
transition_criteria: Sequence[TransitionCriterion] | None = None,
) -> None:
Expand All @@ -59,6 +61,8 @@ def __init__(

Args:
name: Name of the generation node.
suggested_experiment_status: Optional suggested experiment status for this
node. Defaults to None if not specified.
should_deduplicate: Whether to deduplicate the generated points against
the existing trials on the experiment. If True, the duplicate points
will be discarded and re-generated up to 5 times, after which a
Expand All @@ -73,6 +77,7 @@ def __init__(
super().__init__(
name=name,
generator_specs=[],
suggested_experiment_status=suggested_experiment_status,
best_model_selector=None,
should_deduplicate=should_deduplicate,
transition_criteria=transition_criteria,
Expand Down
15 changes: 15 additions & 0 deletions ax/generation_strategy/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.experiment_status import ExperimentStatus
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.trial_status import TrialStatus
Expand Down Expand Up @@ -113,6 +114,10 @@ class GenerationNode(SerializationMixin, SortableBase):
store the most recent previous ``GenerationNode`` name.
should_skip: Whether to skip this node during generation time. Defaults to
False, and can only currently be set to True via ``NodeInputConstructors``
suggested_experiment_status: Optional ``ExperimentStatus`` that indicates
what the experiment's status should be once the experiment adds trials
using ``GeneratorRun``-s produced from this node. This is advisory only
and does not automatically update the experiment's status.
fallback_specs: Optional dict mapping expected exception types to `ModelSpec`
fallbacks used when gen fails.

Expand All @@ -135,6 +140,7 @@ class GenerationNode(SerializationMixin, SortableBase):
_previous_node_name: str | None = None
_trial_type: str | None = None
_should_skip: bool = False
_suggested_experiment_status: ExperimentStatus | None = None
fallback_specs: dict[type[Exception], GeneratorSpec]

# [TODO] Handle experiment passing more eloquently by enforcing experiment
Expand All @@ -156,6 +162,7 @@ def __init__(
previous_node_name: str | None = None,
trial_type: str | None = None,
should_skip: bool = False,
suggested_experiment_status: ExperimentStatus | None = None,
fallback_specs: dict[type[Exception], GeneratorSpec] | None = None,
) -> None:
self._name = name
Expand Down Expand Up @@ -188,6 +195,7 @@ def __init__(
self._previous_node_name = previous_node_name
self._trial_type = trial_type
self._should_skip = should_skip
self._suggested_experiment_status = suggested_experiment_status
self.fallback_specs = (
fallback_specs if fallback_specs is not None else DEFAULT_FALLBACK
)
Expand Down Expand Up @@ -366,6 +374,10 @@ def __repr__(self) -> str:
str_rep += (
f", transition_criteria={str(self._brief_transition_criteria_repr())}"
)
if self._suggested_experiment_status is not None:
str_rep += (
f", suggested_experiment_status={self._suggested_experiment_status!r}"
)
return f"{str_rep})"

def _fit(
Expand Down Expand Up @@ -999,6 +1011,7 @@ class GenerationStep:
whether to transition to the next step. If False, `num_trials` and
`min_trials_observed` will only count trials generatd by this step. If True,
they will count all trials in the experiment (of corresponding statuses).
suggested_experiment_status: The suggested experiment status for this step.

Note for developers: by "generator" here we really mean an ``Adapter`` object, which
contains a ``Generator`` under the hood. We call it "generator" here to simplify and
Expand All @@ -1019,6 +1032,7 @@ def __new__(
use_all_trials_in_exp: bool = False,
use_update: bool = False, # DEPRECATED.
index: int = -1, # Index of this step, set internally.
suggested_experiment_status: ExperimentStatus | None = None,
# Deprecated arguments for backwards compatibility.
model_kwargs: dict[str, Any] | None = None,
model_gen_kwargs: dict[str, Any] | None = None,
Expand Down Expand Up @@ -1135,6 +1149,7 @@ def __new__(
step_index=index, generator_name=resolved_generator_name
),
generator_specs=[generator_spec],
suggested_experiment_status=suggested_experiment_status,
should_deduplicate=should_deduplicate,
transition_criteria=transition_criteria,
)
Expand Down
Loading