2 changes: 1 addition & 1 deletion ax/adapter/tests/test_model_fit_metrics.py
@@ -58,7 +58,7 @@ def test_model_fit_metrics(self) -> None:
         orchestrator = Orchestrator(
             experiment=self.branin_experiment,
             generation_strategy=self.generation_strategy,
-            options=OrchestratorOptions(max_pending_trials=NUM_SOBOL),
+            options=OrchestratorOptions(max_concurrent_trials=NUM_SOBOL),
         )
         # need to run some trials to initialize the Adapter
         orchestrator.run_n_trials(max_trials=NUM_SOBOL + 1)
10 changes: 5 additions & 5 deletions ax/analysis/healthcheck/early_stopping_healthcheck.py
@@ -23,7 +23,7 @@
 from ax.early_stopping.dispatch import get_default_ess_or_none
 from ax.early_stopping.experiment_replay import (
     estimate_hypothetical_early_stopping_savings,
-    MAX_PENDING_TRIALS,
+    MAX_CONCURRENT_TRIALS,
     MIN_SAVINGS_THRESHOLD,
 )
 from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
@@ -81,7 +81,7 @@ def __init__(
         self,
         early_stopping_strategy: BaseEarlyStoppingStrategy | None = None,
         min_savings_threshold: float = MIN_SAVINGS_THRESHOLD,
-        max_pending_trials: int = MAX_PENDING_TRIALS,
+        max_concurrent_trials: int = MAX_CONCURRENT_TRIALS,
         auto_early_stopping_config: AutoEarlyStoppingConfig | None = None,
         nudge_additional_info: str | None = None,
     ) -> None:
@@ -95,7 +95,7 @@ def __init__(
                 single-objective unconstrained experiments.
             min_savings_threshold: Minimum savings threshold to suggest early
                 stopping. Default is 0.1 (10% savings).
-            max_pending_trials: Maximum number of pending trials for replay
+            max_concurrent_trials: Maximum number of concurrent trials for replay
                 orchestrator. Default is 5.
             auto_early_stopping_config: A string for configuring automated early
                 stopping strategy.
@@ -111,7 +111,7 @@ def __init__(
         """
         self.early_stopping_strategy = early_stopping_strategy
         self.min_savings_threshold = min_savings_threshold
-        self.max_pending_trials = max_pending_trials
+        self.max_concurrent_trials = max_concurrent_trials
        self.auto_early_stopping_config = auto_early_stopping_config
        self.nudge_additional_info = nudge_additional_info

@@ -409,7 +409,7 @@ def _report_early_stopping_nudge(
            savings = estimate_hypothetical_early_stopping_savings(
                experiment=experiment,
                metric=metric,
-                max_pending_trials=self.max_pending_trials,
+                max_concurrent_trials=self.max_concurrent_trials,
            )
        except Exception as e:
            # Exception is raised when estimate_hypothetical_early_stopping_savings
2 changes: 1 addition & 1 deletion ax/analysis/healthcheck/tests/test_complexity_rating.py
@@ -215,7 +215,7 @@ def test_unsupported_configurations(self) -> None:
             (
                 "invalid_failure_rate_check",
                 OrchestratorOptions(
-                    max_pending_trials=10,
+                    max_concurrent_trials=10,
                     min_failed_trials_for_failure_rate_check=50,
                 ),
                 {"user_supplied_max_trials": 100, "uses_standard_api": True},
22 changes: 19 additions & 3 deletions ax/api/client.py
@@ -6,6 +6,7 @@
 # pyre-strict

 import json
+import warnings
 from collections.abc import Iterable, Sequence
 from logging import Logger
 from typing import Any, Literal
@@ -43,7 +44,7 @@
     BaseEarlyStoppingStrategy,
     PercentileEarlyStoppingStrategy,
 )
-from ax.exceptions.core import ObjectNotFoundError, UnsupportedError
+from ax.exceptions.core import ObjectNotFoundError, UnsupportedError, UserInputError
 from ax.generation_strategy.generation_strategy import GenerationStrategy
 from ax.orchestration.orchestrator import Orchestrator, OrchestratorOptions
 from ax.service.utils.best_point_mixin import BestPointMixin
@@ -683,9 +684,11 @@ def mark_trial_early_stopped(self, trial_index: int) -> None:
     def run_trials(
         self,
         max_trials: int,
-        parallelism: int = 1,
+        concurrency: int = 1,
         tolerated_trial_failure_rate: float = 0.5,
         initial_seconds_between_polls: int = 1,
+        # Deprecated argument for backwards compatibility.
+        parallelism: int | None = None,
     ) -> None:
         """
         Run maximum_trials trials in a loop by creating an ephemeral Orchestrator under
@@ -694,12 +697,25 @@

         Saves to database on completion if ``storage_config`` is present.
         """
+        # Handle deprecated `parallelism` argument.
+        if parallelism is not None:
+            warnings.warn(
+                "`parallelism` is deprecated and will be removed in Ax 1.4. "
+                "Use `concurrency` instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if concurrency != 1:
+                raise UserInputError(
+                    "Cannot specify both `parallelism` and `concurrency`."
+                )
+            concurrency = parallelism

         orchestrator = Orchestrator(
             experiment=self._experiment,
             generation_strategy=self._generation_strategy_or_choose(),
             options=OrchestratorOptions(
-                max_pending_trials=parallelism,
+                max_concurrent_trials=concurrency,
                 tolerated_trial_failure_rate=tolerated_trial_failure_rate,
                 init_seconds_between_polls=initial_seconds_between_polls,
             ),
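Taken together, the `run_trials` changes keep old call sites working while steering callers to the new name. A minimal usage sketch of both paths (the `Client` experiment setup is elided, and the argument values are arbitrary illustrations):

```python
from ax.api.client import Client

client = Client()
# ... experiment/optimization configuration elided ...

# New spelling: allow up to 4 trials in flight at once.
client.run_trials(max_trials=20, concurrency=4)

# Deprecated spelling: still works, but emits a DeprecationWarning and is
# slated for removal in Ax 1.4. Passing both arguments raises UserInputError.
client.run_trials(max_trials=20, parallelism=4)
```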
12 changes: 6 additions & 6 deletions ax/benchmark/benchmark.py
@@ -137,7 +137,7 @@ def get_benchmark_runner(
             (used to generate data) and ``step_runtime_function`` (used to
             determine timing for the simulator).
         max_concurrency: The maximum number of trials that can be run concurrently.
-            Typically, ``max_pending_trials`` from ``OrchestratorOptions``, which are
+            Typically, ``max_concurrent_trials`` from ``OrchestratorOptions``, which are
             stored on the ``BenchmarkMethod``.
         force_use_simulated_backend: Whether to use a simulated backend even if
             ``max_concurrency`` is 1 and ``problem.step_runtime_function`` is
@@ -226,7 +226,7 @@ def get_oracle_experiment_from_params(
 def get_benchmark_orchestrator_options(
     batch_size: int | None,
     run_trials_in_batches: bool,
-    max_pending_trials: int,
+    max_concurrent_trials: int,
     early_stopping_strategy: BaseEarlyStoppingStrategy | None,
     include_status_quo: bool = False,
     logging_level: int = DEFAULT_LOG_LEVEL,
@@ -240,7 +240,7 @@ def get_benchmark_orchestrator_options(
             for high-throughput settings where there are many trials and
             generating them in bulk reduces overhead (not to be confused with
             `BatchTrial`s, which are different).
-        max_pending_trials: The maximum number of pending trials allowed.
+        max_concurrent_trials: The maximum number of concurrent trials allowed.
         early_stopping_strategy: The early stopping strategy to use (if any).
         include_status_quo: Whether to include the status quo in each trial.
         logging_level: The logging level to use for the Orchestrator.
@@ -255,7 +255,7 @@ def get_benchmark_orchestrator_options(
     return OrchestratorOptions(
         # No new candidates can be generated while any are pending.
         # If batched, an entire batch must finish before the next can be generated.
-        max_pending_trials=max_pending_trials,
+        max_concurrent_trials=max_concurrent_trials,
         # Do not throttle, as is often necessary when polling real endpoints
         init_seconds_between_polls=0,
         min_seconds_before_poll=0,
@@ -568,14 +568,14 @@ def run_optimization_with_orchestrator(
     orchestrator_options = get_benchmark_orchestrator_options(
         batch_size=method.batch_size,
         run_trials_in_batches=run_trials_in_batches,
-        max_pending_trials=method.max_pending_trials,
+        max_concurrent_trials=method.max_concurrent_trials,
         early_stopping_strategy=method.early_stopping_strategy,
         include_status_quo=sq_arm is not None,
         logging_level=orchestrator_logging_level,
     )
     runner = get_benchmark_runner(
         problem=problem,
-        max_concurrency=orchestrator_options.max_pending_trials,
+        max_concurrency=orchestrator_options.max_concurrent_trials,
         force_use_simulated_backend=method.early_stopping_strategy is not None,
     )
     experiment = Experiment(
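Since `get_benchmark_orchestrator_options` now takes `max_concurrent_trials`, a migrated call looks like the following sketch (argument values are arbitrary illustrations; defaults are used for the remaining parameters):

```python
from ax.benchmark.benchmark import get_benchmark_orchestrator_options

options = get_benchmark_orchestrator_options(
    batch_size=1,
    run_trials_in_batches=False,
    max_concurrent_trials=2,  # was `max_pending_trials` before this change
    early_stopping_strategy=None,
)
assert options.max_concurrent_trials == 2
```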
26 changes: 22 additions & 4 deletions ax/benchmark/benchmark_method.py
@@ -5,9 +5,11 @@

 # pyre-strict

-from dataclasses import dataclass
+import warnings
+from dataclasses import dataclass, field

 from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
+from ax.exceptions.core import UserInputError
 from ax.generation_strategy.generation_strategy import GenerationStrategy
 from ax.utils.common.base import Base

@@ -16,25 +18,41 @@
 class BenchmarkMethod(Base):
     """Benchmark method, represented in terms of Ax generation strategy (which tells us
     which models to use when) and Orchestrator options (which tell us extra execution
-    information like maximum parallelism, early stopping configuration, etc.).
+    information like maximum concurrency, early stopping configuration, etc.).

     Args:
         name: String description. Defaults to the name of the generation strategy.
         generation_strategy: The `GenerationStrategy` to use.
         batch_size: Number of arms per trial. Defaults to 1. If greater than 1,
             trials are ``BatchTrial``s; otherwise, they are ``Trial``s.
             Passed to ``OrchestratorOptions``.
-        max_pending_trials: Passed to ``OrchestratorOptions``.
+        max_concurrent_trials: Passed to ``OrchestratorOptions``.
         early_stopping_strategy: Passed to ``OrchestratorOptions``.
     """

     name: str = "DEFAULT"
     generation_strategy: GenerationStrategy
     # Options for the Orchestrator.
     batch_size: int | None = 1
-    max_pending_trials: int = 1
+    max_concurrent_trials: int = 1
     early_stopping_strategy: BaseEarlyStoppingStrategy | None = None
+    # Deprecated
+    max_pending_trials: int | None = field(default=None, repr=False)

     def __post_init__(self) -> None:
+        if self.max_pending_trials is not None:
+            warnings.warn(
+                "`max_pending_trials` is deprecated and will be removed in Ax 1.4. "
+                "Use `max_concurrent_trials` instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if self.max_concurrent_trials != 1:
+                raise UserInputError(
+                    "Cannot specify both `max_pending_trials` and "
+                    "`max_concurrent_trials`."
+                )
+            object.__setattr__(self, "max_concurrent_trials", self.max_pending_trials)
+            object.__setattr__(self, "max_pending_trials", None)
         if self.name == "DEFAULT":
             self.name = self.generation_strategy.name
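The `__post_init__` shim migrates the deprecated field's value and warns once. A rough sketch of the intended behavior, assuming a `GenerationStrategy` instance `gs` is already constructed (its setup is elided, so this is not runnable as-is):

```python
import warnings

from ax.benchmark.benchmark_method import BenchmarkMethod

# `gs` is an existing GenerationStrategy; its construction is elided here.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    method = BenchmarkMethod(generation_strategy=gs, max_pending_trials=4)

assert method.max_concurrent_trials == 4  # value migrated to the new field
assert method.max_pending_trials is None  # deprecated field is cleared
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```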
2 changes: 1 addition & 1 deletion ax/benchmark/benchmark_runner.py
@@ -121,7 +121,7 @@ class BenchmarkRunner(Runner):
         step_runtime_function: A function that takes in parameters
             (in ``TParameterization`` format) and returns the runtime of a step.
         max_concurrency: The maximum number of trials that can be running at a
-            given time. Typically, this is ``max_pending_trials`` from the
+            given time. Typically, this is ``max_concurrent_trials`` from the
             ``orchestrator_options`` on the ``BenchmarkMethod``.
         force_use_simulated_backend: If True, use the simulated backend even if
             ``max_concurrency`` is 1 and ``step_runtime_function`` is None. This
4 changes: 2 additions & 2 deletions ax/benchmark/testing/benchmark_stubs.py
@@ -293,14 +293,14 @@ def get_discrete_search_space(n_values: int = 20) -> SearchSpace:

 def get_async_benchmark_method(
     early_stopping_strategy: BaseEarlyStoppingStrategy | None = None,
-    max_pending_trials: int = 2,
+    max_concurrent_trials: int = 2,
 ) -> BenchmarkMethod:
     gs = GenerationStrategy(
         nodes=[DeterministicGenerationNode(search_space=get_discrete_search_space())]
     )
     return BenchmarkMethod(
         generation_strategy=gs,
-        max_pending_trials=max_pending_trials,
+        max_concurrent_trials=max_concurrent_trials,
         batch_size=1,
         early_stopping_strategy=early_stopping_strategy,
     )
12 changes: 6 additions & 6 deletions ax/benchmark/tests/test_benchmark.py
@@ -619,10 +619,10 @@ def test_early_stopping(self) -> None:
             }
             self.assertEqual(start_times, expected_start_times)

-        with self.subTest("max_pending_trials = 1"):
+        with self.subTest("max_concurrent_trials = 1"):
             method = get_async_benchmark_method(
                 early_stopping_strategy=early_stopping_strategy,
-                max_pending_trials=1,
+                max_concurrent_trials=1,
             )
             experiment = self.run_optimization_with_orchestrator(
                 problem=problem, method=method, seed=0
@@ -652,7 +652,7 @@ def test_early_stopping(self) -> None:
         self.assertEqual(max_run, {0: 4, 1: 2, 2: 2, 3: 2})

     def test_replication_variable_runtime(self) -> None:
-        method = get_async_benchmark_method(max_pending_trials=1)
+        method = get_async_benchmark_method(max_concurrent_trials=1)
         for map_data in [False, True]:
             with self.subTest(map_data=map_data):
                 problem = get_async_benchmark_problem(
@@ -1037,17 +1037,17 @@ def test_get_benchmark_orchestrator_options(self) -> None:
             generation_strategy=get_sobol_mbm_generation_strategy(
                 model_cls=SingleTaskGP, acquisition_cls=qLogNoisyExpectedImprovement
             ),
-            max_pending_trials=2,
+            max_concurrent_trials=2,
             batch_size=batch_size,
         )
         orchestrator_options = get_benchmark_orchestrator_options(
             batch_size=none_throws(method.batch_size),
             run_trials_in_batches=False,
-            max_pending_trials=method.max_pending_trials,
+            max_concurrent_trials=method.max_concurrent_trials,
             early_stopping_strategy=method.early_stopping_strategy,
             include_status_quo=include_sq,
         )
-        self.assertEqual(orchestrator_options.max_pending_trials, 2)
+        self.assertEqual(orchestrator_options.max_concurrent_trials, 2)
         self.assertEqual(orchestrator_options.init_seconds_between_polls, 0)
         self.assertEqual(orchestrator_options.min_seconds_before_poll, 0)
         self.assertEqual(orchestrator_options.batch_size, batch_size)
4 changes: 2 additions & 2 deletions ax/core/runner.py
@@ -80,8 +80,8 @@ def poll_available_capacity(self) -> int:
         as is possible without violating Orchestrator settings). There is no need to
         artificially force this method to limit capacity; ``Orchestrator`` has other
         limitations in place to limit number of trials running at once,
-        like the ``OrchestratorOptions.max_pending_trials`` setting, or
-        more granular control in the form of the `max_parallelism`
+        like the ``OrchestratorOptions.max_concurrent_trials`` setting, or
+        more granular control in the form of the `max_concurrency`
         setting in each of the `GenerationStep`s of a `GenerationStrategy`).

         Returns:
12 changes: 6 additions & 6 deletions ax/early_stopping/experiment_replay.py
@@ -35,7 +35,7 @@
 # Constants for experiment replay
 MAX_REPLAY_TRIALS: int = 50
 REPLAY_NUM_POINTS_PER_CURVE: int = 20
-MAX_PENDING_TRIALS: int = 5
+MAX_CONCURRENT_TRIALS: int = 5
 MIN_SAVINGS_THRESHOLD: float = 0.1  # 10% threshold


@@ -44,7 +44,7 @@ def replay_experiment(
     num_samples_per_curve: int,
     max_replay_trials: int,
     metric: Metric,
-    max_pending_trials: int,
+    max_concurrent_trials: int,
     early_stopping_strategy: BaseEarlyStoppingStrategy | None,
     logging_level: int = logging.ERROR,
 ) -> Experiment | None:
@@ -99,7 +99,7 @@ def replay_experiment(
         ],
     )
     options = OrchestratorOptions(
-        max_pending_trials=max_pending_trials,
+        max_concurrent_trials=max_concurrent_trials,
         total_trials=min(len(historical_experiment.trials), max_replay_trials),
         seconds_between_polls_backoff_factor=1.0,
         min_seconds_before_poll=0.0,
@@ -119,7 +119,7 @@
 def estimate_hypothetical_early_stopping_savings(
     experiment: Experiment,
     metric: Metric,
-    max_pending_trials: int = MAX_PENDING_TRIALS,
+    max_concurrent_trials: int = MAX_CONCURRENT_TRIALS,
 ) -> float:
     """Estimate hypothetical early stopping savings using experiment replay.

@@ -130,7 +130,7 @@ def estimate_hypothetical_early_stopping_savings(
     Args:
         experiment: The experiment to analyze.
         metric: The metric to use for early stopping replay.
-        max_pending_trials: Maximum number of pending trials for the replay
+        max_concurrent_trials: Maximum number of concurrent trials for the replay
             orchestrator. Defaults to 5.

     Returns:
@@ -156,7 +156,7 @@ def estimate_hypothetical_early_stopping_savings(
         num_samples_per_curve=REPLAY_NUM_POINTS_PER_CURVE,
         max_replay_trials=MAX_REPLAY_TRIALS,
         metric=metric,
-        max_pending_trials=max_pending_trials,
+        max_concurrent_trials=max_concurrent_trials,
         early_stopping_strategy=default_ess,
     )
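For downstream callers of the replay utilities, the migration is mechanical. A hedged sketch of the renamed call, where `experiment` and `metric` are placeholders for objects taken from a real Ax experiment:

```python
from ax.early_stopping.experiment_replay import (
    MAX_CONCURRENT_TRIALS,
    MIN_SAVINGS_THRESHOLD,
    estimate_hypothetical_early_stopping_savings,
)

# `experiment` and `metric` are placeholders for a real Experiment and Metric.
savings = estimate_hypothetical_early_stopping_savings(
    experiment=experiment,
    metric=metric,
    max_concurrent_trials=MAX_CONCURRENT_TRIALS,  # default replay concurrency: 5
)
if savings >= MIN_SAVINGS_THRESHOLD:  # 0.1, i.e. 10%
    print(f"Early stopping could have saved ~{savings:.0%} of trial effort.")
```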