From 41bb89f00557d19db570a69325346181056847c2 Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Thu, 5 Feb 2026 16:29:57 -0800 Subject: [PATCH] rename parallelism=>concurrency across all of Ax Differential Revision: D92457714 --- ax/adapter/tests/test_model_fit_metrics.py | 2 +- .../healthcheck/early_stopping_healthcheck.py | 10 +- .../tests/test_complexity_rating.py | 2 +- ax/api/client.py | 22 ++- ax/benchmark/benchmark.py | 12 +- ax/benchmark/benchmark_method.py | 26 +++- ax/benchmark/benchmark_runner.py | 2 +- ax/benchmark/testing/benchmark_stubs.py | 4 +- ax/benchmark/tests/test_benchmark.py | 12 +- ax/core/runner.py | 4 +- ax/early_stopping/experiment_replay.py | 12 +- ax/generation_strategy/dispatch_utils.py | 131 +++++++++++------- ax/generation_strategy/generation_node.py | 32 +++-- .../tests/test_dispatch_utils.py | 68 ++++----- ax/orchestration/orchestrator.py | 23 +-- ax/orchestration/orchestrator_options.py | 31 ++++- ax/orchestration/tests/test_orchestrator.py | 12 +- ax/service/ax_client.py | 42 +++--- ax/service/tests/test_ax_client.py | 20 +-- ax/utils/common/complexity_utils.py | 16 +-- .../common/tests/test_complexity_utils.py | 10 +- 21 files changed, 303 insertions(+), 190 deletions(-) diff --git a/ax/adapter/tests/test_model_fit_metrics.py b/ax/adapter/tests/test_model_fit_metrics.py index 99dd6c3042e..171a1e160d3 100644 --- a/ax/adapter/tests/test_model_fit_metrics.py +++ b/ax/adapter/tests/test_model_fit_metrics.py @@ -58,7 +58,7 @@ def test_model_fit_metrics(self) -> None: orchestrator = Orchestrator( experiment=self.branin_experiment, generation_strategy=self.generation_strategy, - options=OrchestratorOptions(max_pending_trials=NUM_SOBOL), + options=OrchestratorOptions(max_concurrent_trials=NUM_SOBOL), ) # need to run some trials to initialize the Adapter orchestrator.run_n_trials(max_trials=NUM_SOBOL + 1) diff --git a/ax/analysis/healthcheck/early_stopping_healthcheck.py b/ax/analysis/healthcheck/early_stopping_healthcheck.py index 49bfca2e30c..819e201341c 100644 --- a/ax/analysis/healthcheck/early_stopping_healthcheck.py +++ b/ax/analysis/healthcheck/early_stopping_healthcheck.py @@ -23,7 +23,7 @@ from ax.early_stopping.dispatch import get_default_ess_or_none from ax.early_stopping.experiment_replay import ( estimate_hypothetical_early_stopping_savings, - MAX_PENDING_TRIALS, + MAX_CONCURRENT_TRIALS, MIN_SAVINGS_THRESHOLD, ) from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy @@ -81,7 +81,7 @@ def __init__( self, early_stopping_strategy: BaseEarlyStoppingStrategy | None = None, min_savings_threshold: float = MIN_SAVINGS_THRESHOLD, - max_pending_trials: int = MAX_PENDING_TRIALS, + max_concurrent_trials: int = MAX_CONCURRENT_TRIALS, auto_early_stopping_config: AutoEarlyStoppingConfig | None = None, nudge_additional_info: str | None = None, ) -> None: @@ -95,7 +95,7 @@ def __init__( single-objective unconstrained experiments. min_savings_threshold: Minimum savings threshold to suggest early stopping. Default is 0.1 (10% savings). - max_pending_trials: Maximum number of pending trials for replay + max_concurrent_trials: Maximum number of concurrent trials for replay orchestrator. Default is 5. auto_early_stopping_config: A string for configuring automated early stopping strategy. 
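NOTE: For reviewers, the mechanical shape of this rename at call sites, sketched
with the OrchestratorOptions usage from the test above (NUM_SOBOL stands in for
any trial count; the import path is the one used in this diff):

    from ax.orchestration.orchestrator import OrchestratorOptions

    NUM_SOBOL = 5
    # Before this diff: OrchestratorOptions(max_pending_trials=NUM_SOBOL)
    options = OrchestratorOptions(max_concurrent_trials=NUM_SOBOL)
    assert options.max_concurrent_trials == NUM_SOBOL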
@@ -111,7 +111,7 @@ def __init__( """ self.early_stopping_strategy = early_stopping_strategy self.min_savings_threshold = min_savings_threshold - self.max_pending_trials = max_pending_trials + self.max_concurrent_trials = max_concurrent_trials self.auto_early_stopping_config = auto_early_stopping_config self.nudge_additional_info = nudge_additional_info @@ -409,7 +409,7 @@ def _report_early_stopping_nudge( savings = estimate_hypothetical_early_stopping_savings( experiment=experiment, metric=metric, - max_pending_trials=self.max_pending_trials, + max_concurrent_trials=self.max_concurrent_trials, ) except Exception as e: # Exception is raised when estimate_hypothetical_early_stopping_savings diff --git a/ax/analysis/healthcheck/tests/test_complexity_rating.py b/ax/analysis/healthcheck/tests/test_complexity_rating.py index 61f4ff75cbc..5aad26efde1 100644 --- a/ax/analysis/healthcheck/tests/test_complexity_rating.py +++ b/ax/analysis/healthcheck/tests/test_complexity_rating.py @@ -215,7 +215,7 @@ def test_unsupported_configurations(self) -> None: ( "invalid_failure_rate_check", OrchestratorOptions( - max_pending_trials=10, + max_concurrent_trials=10, min_failed_trials_for_failure_rate_check=50, ), {"user_supplied_max_trials": 100, "uses_standard_api": True}, diff --git a/ax/api/client.py b/ax/api/client.py index dc1abb18c55..0de53494589 100644 --- a/ax/api/client.py +++ b/ax/api/client.py @@ -6,6 +6,7 @@ # pyre-strict import json +import warnings from collections.abc import Iterable, Sequence from logging import Logger from typing import Any, Literal @@ -43,7 +44,7 @@ BaseEarlyStoppingStrategy, PercentileEarlyStoppingStrategy, ) -from ax.exceptions.core import ObjectNotFoundError, UnsupportedError +from ax.exceptions.core import ObjectNotFoundError, UnsupportedError, UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.orchestration.orchestrator import Orchestrator, OrchestratorOptions from ax.service.utils.best_point_mixin import BestPointMixin @@ -683,9 +684,11 @@ def mark_trial_early_stopped(self, trial_index: int) -> None: def run_trials( self, max_trials: int, - parallelism: int = 1, + concurrency: int = 1, tolerated_trial_failure_rate: float = 0.5, initial_seconds_between_polls: int = 1, + # Deprecated argument for backwards compatibility. + parallelism: int | None = None, ) -> None: """ Run maximum_trials trials in a loop by creating an ephemeral Orchestrator under @@ -694,12 +697,25 @@ def run_trials( Saves to database on completion if ``storage_config`` is present. """ + # Handle deprecated `parallelism` argument. + if parallelism is not None: + warnings.warn( + "`parallelism` is deprecated and will be removed in Ax 1.4. " + "Use `concurrency` instead.", + DeprecationWarning, + stacklevel=2, + ) + if concurrency != 1: + raise UserInputError( + "Cannot specify both `parallelism` and `concurrency`." 
+                )
+            concurrency = parallelism
         orchestrator = Orchestrator(
             experiment=self._experiment,
             generation_strategy=self._generation_strategy_or_choose(),
             options=OrchestratorOptions(
-                max_pending_trials=parallelism,
+                max_concurrent_trials=concurrency,
                 tolerated_trial_failure_rate=tolerated_trial_failure_rate,
                 init_seconds_between_polls=initial_seconds_between_polls,
             ),
diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py
index a7b93ee82a7..e4ed7c58560 100644
--- a/ax/benchmark/benchmark.py
+++ b/ax/benchmark/benchmark.py
@@ -137,7 +137,7 @@ def get_benchmark_runner(
         (used to generate data) and ``step_runtime_function`` (used to determine
         timing for the simulator).
         max_concurrency: The maximum number of trials that can be run concurrently.
-            Typically, ``max_pending_trials`` from ``OrchestratorOptions``, which are
+            Typically, ``max_concurrent_trials`` from ``OrchestratorOptions``, which are
             stored on the ``BenchmarkMethod``.
         force_use_simulated_backend: Whether to use a simulated backend even if
             ``max_concurrency`` is 1 and ``problem.step_runtime_function`` is
@@ -226,7 +226,7 @@ def get_oracle_experiment_from_params(
 def get_benchmark_orchestrator_options(
     batch_size: int | None,
     run_trials_in_batches: bool,
-    max_pending_trials: int,
+    max_concurrent_trials: int,
     early_stopping_strategy: BaseEarlyStoppingStrategy | None,
     include_status_quo: bool = False,
     logging_level: int = DEFAULT_LOG_LEVEL,
@@ -240,7 +240,7 @@ def get_benchmark_orchestrator_options(
         for high-throughput settings where there are many trials and generating
         them in bulk reduces overhead (not to be confused with `BatchTrial`s,
         which are different).
-        max_pending_trials: The maximum number of pending trials allowed.
+        max_concurrent_trials: The maximum number of concurrent trials allowed.
         early_stopping_strategy: The early stopping strategy to use (if any).
         include_status_quo: Whether to include the status quo in each trial.
         logging_level: The logging level to use for the Orchestrator.
@@ -255,7 +255,7 @@ def get_benchmark_orchestrator_options(
     return OrchestratorOptions(
         # No new candidates can be generated while any are pending.
         # If batched, an entire batch must finish before the next can be generated.
- max_pending_trials=max_pending_trials, + max_concurrent_trials=max_concurrent_trials, # Do not throttle, as is often necessary when polling real endpoints init_seconds_between_polls=0, min_seconds_before_poll=0, @@ -568,14 +568,14 @@ def run_optimization_with_orchestrator( orchestrator_options = get_benchmark_orchestrator_options( batch_size=method.batch_size, run_trials_in_batches=run_trials_in_batches, - max_pending_trials=method.max_pending_trials, + max_concurrent_trials=method.max_concurrent_trials, early_stopping_strategy=method.early_stopping_strategy, include_status_quo=sq_arm is not None, logging_level=orchestrator_logging_level, ) runner = get_benchmark_runner( problem=problem, - max_concurrency=orchestrator_options.max_pending_trials, + max_concurrency=orchestrator_options.max_concurrent_trials, force_use_simulated_backend=method.early_stopping_strategy is not None, ) experiment = Experiment( diff --git a/ax/benchmark/benchmark_method.py b/ax/benchmark/benchmark_method.py index bea5463bd2f..b1e9d9c1682 100644 --- a/ax/benchmark/benchmark_method.py +++ b/ax/benchmark/benchmark_method.py @@ -5,9 +5,11 @@ # pyre-strict -from dataclasses import dataclass +import warnings +from dataclasses import dataclass, field from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy +from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.utils.common.base import Base @@ -16,7 +18,7 @@ class BenchmarkMethod(Base): """Benchmark method, represented in terms of Ax generation strategy (which tells us which models to use when) and Orchestrator options (which tell us extra execution - information like maximum parallelism, early stopping configuration, etc.). + information like maximum concurrency, early stopping configuration, etc.). Args: name: String description. Defaults to the name of the generation strategy. @@ -24,7 +26,7 @@ class BenchmarkMethod(Base): batch_size: Number of arms per trial. Defaults to 1. If greater than 1, trials are ``BatchTrial``s; otherwise, they are ``Trial``s. Passed to ``OrchestratorOptions``. - max_pending_trials: Passed to ``OrchestratorOptions``. + max_concurrent_trials: Passed to ``OrchestratorOptions``. early_stopping_strategy: Passed to ``OrchestratorOptions``. """ @@ -32,9 +34,25 @@ class BenchmarkMethod(Base): generation_strategy: GenerationStrategy # Options for the Orchestrator. batch_size: int | None = 1 - max_pending_trials: int = 1 + max_concurrent_trials: int = 1 early_stopping_strategy: BaseEarlyStoppingStrategy | None = None + # Deprecated + max_pending_trials: int | None = field(default=None, repr=False) def __post_init__(self) -> None: + if self.max_pending_trials is not None: + warnings.warn( + "`max_pending_trials` is deprecated and will be removed in Ax 1.4. " + "Use `max_concurrent_trials` instead.", + DeprecationWarning, + stacklevel=2, + ) + if self.max_concurrent_trials != 1: + raise UserInputError( + "Cannot specify both `max_pending_trials` and " + "`max_concurrent_trials`." 
+ ) + object.__setattr__(self, "max_concurrent_trials", self.max_pending_trials) + object.__setattr__(self, "max_pending_trials", None) if self.name == "DEFAULT": self.name = self.generation_strategy.name diff --git a/ax/benchmark/benchmark_runner.py b/ax/benchmark/benchmark_runner.py index 83b502245b7..c8779cb3794 100644 --- a/ax/benchmark/benchmark_runner.py +++ b/ax/benchmark/benchmark_runner.py @@ -121,7 +121,7 @@ class BenchmarkRunner(Runner): step_runtime_function: A function that takes in parameters (in ``TParameterization`` format) and returns the runtime of a step. max_concurrency: The maximum number of trials that can be running at a - given time. Typically, this is ``max_pending_trials`` from the + given time. Typically, this is ``max_concurrent_trials`` from the ``orchestrator_options`` on the ``BenchmarkMethod``. force_use_simulated_backend: If True, use the simulated backend even if ``max_concurrency`` is 1 and ``step_runtime_function`` is None. This diff --git a/ax/benchmark/testing/benchmark_stubs.py b/ax/benchmark/testing/benchmark_stubs.py index 2a9e2f35006..c99bb282c74 100644 --- a/ax/benchmark/testing/benchmark_stubs.py +++ b/ax/benchmark/testing/benchmark_stubs.py @@ -293,14 +293,14 @@ def get_discrete_search_space(n_values: int = 20) -> SearchSpace: def get_async_benchmark_method( early_stopping_strategy: BaseEarlyStoppingStrategy | None = None, - max_pending_trials: int = 2, + max_concurrent_trials: int = 2, ) -> BenchmarkMethod: gs = GenerationStrategy( nodes=[DeterministicGenerationNode(search_space=get_discrete_search_space())] ) return BenchmarkMethod( generation_strategy=gs, - max_pending_trials=max_pending_trials, + max_concurrent_trials=max_concurrent_trials, batch_size=1, early_stopping_strategy=early_stopping_strategy, ) diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py index c58b1441062..cd0d383accc 100644 --- a/ax/benchmark/tests/test_benchmark.py +++ b/ax/benchmark/tests/test_benchmark.py @@ -619,10 +619,10 @@ def test_early_stopping(self) -> None: } self.assertEqual(start_times, expected_start_times) - with self.subTest("max_pending_trials = 1"): + with self.subTest("max_concurrent_trials = 1"): method = get_async_benchmark_method( early_stopping_strategy=early_stopping_strategy, - max_pending_trials=1, + max_concurrent_trials=1, ) experiment = self.run_optimization_with_orchestrator( problem=problem, method=method, seed=0 @@ -652,7 +652,7 @@ def test_early_stopping(self) -> None: self.assertEqual(max_run, {0: 4, 1: 2, 2: 2, 3: 2}) def test_replication_variable_runtime(self) -> None: - method = get_async_benchmark_method(max_pending_trials=1) + method = get_async_benchmark_method(max_concurrent_trials=1) for map_data in [False, True]: with self.subTest(map_data=map_data): problem = get_async_benchmark_problem( @@ -1037,17 +1037,17 @@ def test_get_benchmark_orchestrator_options(self) -> None: generation_strategy=get_sobol_mbm_generation_strategy( model_cls=SingleTaskGP, acquisition_cls=qLogNoisyExpectedImprovement ), - max_pending_trials=2, + max_concurrent_trials=2, batch_size=batch_size, ) orchestrator_options = get_benchmark_orchestrator_options( batch_size=none_throws(method.batch_size), run_trials_in_batches=False, - max_pending_trials=method.max_pending_trials, + max_concurrent_trials=method.max_concurrent_trials, early_stopping_strategy=method.early_stopping_strategy, include_status_quo=include_sq, ) - self.assertEqual(orchestrator_options.max_pending_trials, 2) + 
self.assertEqual(orchestrator_options.max_concurrent_trials, 2) self.assertEqual(orchestrator_options.init_seconds_between_polls, 0) self.assertEqual(orchestrator_options.min_seconds_before_poll, 0) self.assertEqual(orchestrator_options.batch_size, batch_size) diff --git a/ax/core/runner.py b/ax/core/runner.py index 0fcf8abff20..62427ceb860 100644 --- a/ax/core/runner.py +++ b/ax/core/runner.py @@ -80,8 +80,8 @@ def poll_available_capacity(self) -> int: as is possible without violating Orchestrator settings). There is no need to artificially force this method to limit capacity; ``Orchestrator`` has other limitations in place to limit number of trials running at once, - like the ``OrchestratorOptions.max_pending_trials`` setting, or - more granular control in the form of the `max_parallelism` + like the ``OrchestratorOptions.max_concurrent_trials`` setting, or + more granular control in the form of the `max_concurrency` setting in each of the `GenerationStep`s of a `GenerationStrategy`). Returns: diff --git a/ax/early_stopping/experiment_replay.py b/ax/early_stopping/experiment_replay.py index d4a1f5a5fe2..0f34128c9b4 100644 --- a/ax/early_stopping/experiment_replay.py +++ b/ax/early_stopping/experiment_replay.py @@ -35,7 +35,7 @@ # Constants for experiment replay MAX_REPLAY_TRIALS: int = 50 REPLAY_NUM_POINTS_PER_CURVE: int = 20 -MAX_PENDING_TRIALS: int = 5 +MAX_CONCURRENT_TRIALS: int = 5 MIN_SAVINGS_THRESHOLD: float = 0.1 # 10% threshold @@ -44,7 +44,7 @@ def replay_experiment( num_samples_per_curve: int, max_replay_trials: int, metric: Metric, - max_pending_trials: int, + max_concurrent_trials: int, early_stopping_strategy: BaseEarlyStoppingStrategy | None, logging_level: int = logging.ERROR, ) -> Experiment | None: @@ -99,7 +99,7 @@ def replay_experiment( ], ) options = OrchestratorOptions( - max_pending_trials=max_pending_trials, + max_concurrent_trials=max_concurrent_trials, total_trials=min(len(historical_experiment.trials), max_replay_trials), seconds_between_polls_backoff_factor=1.0, min_seconds_before_poll=0.0, @@ -119,7 +119,7 @@ def replay_experiment( def estimate_hypothetical_early_stopping_savings( experiment: Experiment, metric: Metric, - max_pending_trials: int = MAX_PENDING_TRIALS, + max_concurrent_trials: int = MAX_CONCURRENT_TRIALS, ) -> float: """Estimate hypothetical early stopping savings using experiment replay. @@ -130,7 +130,7 @@ def estimate_hypothetical_early_stopping_savings( Args: experiment: The experiment to analyze. metric: The metric to use for early stopping replay. - max_pending_trials: Maximum number of pending trials for the replay + max_concurrent_trials: Maximum number of concurrent trials for the replay orchestrator. Defaults to 5. 
Returns: @@ -156,7 +156,7 @@ def estimate_hypothetical_early_stopping_savings( num_samples_per_curve=REPLAY_NUM_POINTS_PER_CURVE, max_replay_trials=MAX_REPLAY_TRIALS, metric=metric, - max_pending_trials=max_pending_trials, + max_concurrent_trials=max_concurrent_trials, early_stopping_strategy=default_ess, ) diff --git a/ax/generation_strategy/dispatch_utils.py b/ax/generation_strategy/dispatch_utils.py index bc0957896c5..8bf72d0d2d1 100644 --- a/ax/generation_strategy/dispatch_utils.py +++ b/ax/generation_strategy/dispatch_utils.py @@ -7,6 +7,7 @@ # pyre-strict import logging +import warnings from math import ceil from typing import Any @@ -16,6 +17,7 @@ from ax.core.optimization_config import OptimizationConfig from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter from ax.core.search_space import SearchSpace +from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import ( GenerationStep, GenerationStrategy, @@ -31,7 +33,7 @@ logger: logging.Logger = get_logger(__name__) -DEFAULT_BAYESIAN_PARALLELISM = 3 +DEFAULT_BAYESIAN_CONCURRENCY = 3 # `BO_MIXED` optimizes all range parameters once for each combination of choice # parameters, then takes the optimum of those optima. The cost associated with this # method grows with the number of combinations, and so it is only used when the @@ -49,7 +51,7 @@ def _make_sobol_step( num_trials: int = -1, min_trials_observed: int | None = None, enforce_num_trials: bool = True, - max_parallelism: int | None = None, + max_concurrency: int | None = None, seed: int | None = None, should_deduplicate: bool = False, ) -> GenerationStep: @@ -60,7 +62,7 @@ def _make_sobol_step( # NOTE: ceil(-1 / 2) = 0, so this is safe to do when num trials is -1. min_trials_observed=min_trials_observed or ceil(num_trials / 2), enforce_num_trials=enforce_num_trials, - max_parallelism=max_parallelism, + max_concurrency=max_concurrency, generator_kwargs={"deduplicate": True, "seed": seed}, should_deduplicate=should_deduplicate, use_all_trials_in_exp=True, @@ -71,7 +73,7 @@ def _make_botorch_step( num_trials: int = -1, min_trials_observed: int | None = None, enforce_num_trials: bool = True, - max_parallelism: int | None = None, + max_concurrency: int | None = None, generator: GeneratorRegistryBase = Generators.BOTORCH_MODULAR, generator_kwargs: dict[str, Any] | None = None, winsorization_config: None @@ -126,7 +128,7 @@ def _make_botorch_step( # NOTE: ceil(-1 / 2) = 0, so this is safe to do when num trials is -1. min_trials_observed=min_trials_observed or ceil(num_trials / 2), enforce_num_trials=enforce_num_trials, - max_parallelism=max_parallelism, + max_concurrency=max_concurrency, generator_kwargs=generator_kwargs, should_deduplicate=should_deduplicate, ) @@ -296,8 +298,8 @@ def choose_generation_strategy_legacy( num_completed_initialization_trials: int = 0, max_initialization_trials: int | None = None, min_sobol_trials_observed: int | None = None, - max_parallelism_cap: int | None = None, - max_parallelism_override: int | None = None, + max_concurrency_cap: int | None = None, + max_concurrency_override: int | None = None, optimization_config: OptimizationConfig | None = None, should_deduplicate: bool = False, use_saasbo: bool = False, @@ -307,6 +309,9 @@ def choose_generation_strategy_legacy( suggested_model_override: GeneratorRegistryBase | None = None, use_input_warping: bool = False, simplify_parameter_changes: bool = False, + # Deprecated arguments for backwards compatibility. 
+ max_parallelism_cap: int | None = None, + max_parallelism_override: int | None = None, ) -> GenerationStrategy: """Select an appropriate generation strategy based on the properties of the search space and expected settings of the experiment, such as number of @@ -321,11 +326,11 @@ def choose_generation_strategy_legacy( enforce_sequential_optimization: Whether to enforce that 1) the generation strategy needs to be updated with ``min_trials_observed`` observations for a given generation step before proceeding to the next one and 2) maximum - number of trials running at once (max_parallelism) if enforced for the - BayesOpt step. NOTE: ``max_parallelism_override`` and - ``max_parallelism_cap`` settings will still take their effect on max - parallelism even if ``enforce_sequential_optimization=False``, so if those - settings are specified, max parallelism will be enforced. + number of trials running at once (max_concurrency) if enforced for the + BayesOpt step. NOTE: ``max_concurrency_override`` and + ``max_concurrency_cap`` settings will still take their effect on max + concurrency even if ``enforce_sequential_optimization=False``, so if those + settings are specified, max concurrency will be enforced. random_seed: Fixed random seed for the Sobol generator. torch_device: The device to use for generation steps implemented in PyTorch (e.g. via BoTorch). Some generation steps (in particular EHVI-based ones @@ -356,21 +361,21 @@ def choose_generation_strategy_legacy( min_sobol_trials_observed: Minimum number of Sobol trials that must be observed before proceeding to the next generation step. Defaults to `ceil(num_initialization_trials / 2)`. - max_parallelism_cap: Integer cap on parallelism in this generation strategy. - If specified, ``max_parallelism`` setting in each generation step will be + max_concurrency_cap: Integer cap on concurrency in this generation strategy. + If specified, ``max_concurrency`` setting in each generation step will be set to the minimum of the default setting for that step and the value of - this cap. ``max_parallelism_cap`` is meant to just be a hard limit on - parallelism (e.g. to avoid overloading machine(s) that evaluate the + this cap. ``max_concurrency_cap`` is meant to just be a hard limit on + concurrency (e.g. to avoid overloading machine(s) that evaluate the experiment trials). Specify only if not specifying - ``max_parallelism_override``. - max_parallelism_override: Integer, with which to override the default max - parallelism setting for all steps in the generation strategy returned from - this function. Each generation step has a ``max_parallelism`` value, which + ``max_concurrency_override``. + max_concurrency_override: Integer, with which to override the default max + concurrency setting for all steps in the generation strategy returned from + this function. Each generation step has a ``max_concurrency`` value, which restricts how many trials can run simultaneously during a given generation - step. By default, the parallelism setting is chosen as appropriate for the - model in a given generation step. If ``max_parallelism_override`` is -1, - no max parallelism will be enforced for any step of the generation - strategy. Be aware that parallelism is limited to improve performance of + step. By default, the concurrency setting is chosen as appropriate for the + model in a given generation step. If ``max_concurrency_override`` is -1, + no max concurrency will be enforced for any step of the generation + strategy. 
Be aware that concurrency is limited to improve performance of Bayesian optimization, so only disable its limiting if necessary. optimization_config: Used to infer whether to use MOO. should_deduplicate: Whether to deduplicate the parameters of proposed arms @@ -403,6 +408,34 @@ def choose_generation_strategy_legacy( arms generated via Bayesian Optimization by pruning irrelevant parameter changes. """ + # Handle deprecated arguments. + if max_parallelism_cap is not None: + warnings.warn( + "`max_parallelism_cap` is deprecated and will be removed in Ax 1.4. " + "Use `max_concurrency_cap` instead.", + DeprecationWarning, + stacklevel=2, + ) + if max_concurrency_cap is not None: + raise UserInputError( + "Cannot specify both `max_parallelism_cap` and `max_concurrency_cap`." + ) + max_concurrency_cap = max_parallelism_cap + + if max_parallelism_override is not None: + warnings.warn( + "`max_parallelism_override` is deprecated and will be removed in Ax 1.4. " + "Use `max_concurrency_override` instead.", + DeprecationWarning, + stacklevel=2, + ) + if max_concurrency_override is not None: + raise UserInputError( + "Cannot specify both `max_parallelism_override` and " + "`max_concurrency_override`." + ) + max_concurrency_override = max_parallelism_override + if experiment is not None and optimization_config is None: optimization_config = experiment.optimization_config @@ -412,36 +445,36 @@ def choose_generation_strategy_legacy( optimization_config=optimization_config, use_saasbo=use_saasbo, ) - # Determine max parallelism for the generation steps. - if max_parallelism_override == -1: - # `max_parallelism_override` of -1 means no max parallelism enforcement in - # the generation strategy, which means `max_parallelism=None` in gen. steps. - sobol_parallelism = bo_parallelism = None - elif max_parallelism_override is not None: - sobol_parallelism = bo_parallelism = max_parallelism_override - elif max_parallelism_cap is not None: # Max parallelism override is None by now - sobol_parallelism = max_parallelism_cap - bo_parallelism = min(max_parallelism_cap, DEFAULT_BAYESIAN_PARALLELISM) + # Determine max concurrency for the generation steps. + if max_concurrency_override == -1: + # `max_concurrency_override` of -1 means no max concurrency enforcement in + # the generation strategy, which means `max_concurrency=None` in gen. steps. + sobol_concurrency = bo_concurrency = None + elif max_concurrency_override is not None: + sobol_concurrency = bo_concurrency = max_concurrency_override + elif max_concurrency_cap is not None: # Max concurrency override is None by now + sobol_concurrency = max_concurrency_cap + bo_concurrency = min(max_concurrency_cap, DEFAULT_BAYESIAN_CONCURRENCY) elif not enforce_sequential_optimization: - # If no max parallelism settings specified and not enforcing sequential - # optimization, do not limit parallelism. - sobol_parallelism = bo_parallelism = None - else: # No additional max parallelism settings, use defaults - sobol_parallelism = None # No restriction on Sobol phase - bo_parallelism = DEFAULT_BAYESIAN_PARALLELISM + # If no max concurrency settings specified and not enforcing sequential + # optimization, do not limit concurrency. 
+ sobol_concurrency = bo_concurrency = None + else: # No additional max concurrency settings, use defaults + sobol_concurrency = None # No restriction on Sobol phase + bo_concurrency = DEFAULT_BAYESIAN_CONCURRENCY if not force_random_search and suggested_model is not None: if not enforce_sequential_optimization and ( - max_parallelism_override or max_parallelism_cap + max_concurrency_override or max_concurrency_cap ): logger.info( - "If `enforce_sequential_optimization` is False, max parallelism is " - "not enforced and other max parallelism settings will be ignored." + "If `enforce_sequential_optimization` is False, max concurrency is " + "not enforced and other max concurrency settings will be ignored." ) - if max_parallelism_override and max_parallelism_cap: + if max_concurrency_override and max_concurrency_cap: raise ValueError( - "If `max_parallelism_override` specified, cannot also apply " - "`max_parallelism_cap`." + "If `max_concurrency_override` specified, cannot also apply " + "`max_concurrency_cap`." ) # If number of initialization trials is not specified, estimate it. @@ -499,7 +532,7 @@ def choose_generation_strategy_legacy( min_trials_observed=min_sobol_trials_observed, enforce_num_trials=enforce_sequential_optimization, seed=random_seed, - max_parallelism=sobol_parallelism, + max_concurrency=sobol_concurrency, should_deduplicate=should_deduplicate, ) ) @@ -508,7 +541,7 @@ def choose_generation_strategy_legacy( generator=suggested_model, winsorization_config=winsorization_config, derelativize_with_raw_status_quo=derelativize_with_raw_status_quo, - max_parallelism=bo_parallelism, + max_concurrency=bo_concurrency, generator_kwargs=generator_kwargs, should_deduplicate=should_deduplicate, disable_progbar=disable_progbar, @@ -540,7 +573,7 @@ def choose_generation_strategy_legacy( _make_sobol_step( seed=random_seed, should_deduplicate=should_deduplicate, - max_parallelism=sobol_parallelism, + max_concurrency=sobol_concurrency, ) ] ) diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index a3efebc8cd7..dcff9f8fd0e 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -8,6 +8,7 @@ from __future__ import annotations +import warnings from collections import defaultdict from collections.abc import Sequence from logging import Logger @@ -958,9 +959,9 @@ class GenerationStep: If `num_trials` of a given step have been generated but `min_trials_ observed` have not been completed, a call to `generation_strategy.gen` will fail with a `DataRequiredError`. - max_parallelism: How many trials generated in the course of this step are + max_concurrency: How many trials generated in the course of this step are allowed to be run (i.e. have `trial.status` of `RUNNING`) simultaneously. - If `max_parallelism` trials from this step are already running, a call + If `max_concurrency` trials from this step are already running, a call to `generation_strategy.gen` will fail with a `MaxParallelismReached Exception`, indicating that more trials need to be completed before generating and running next trials. 
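NOTE: A quick behavioral sketch of the dispatch_utils settings above, assuming
the shim lands as written (`get_branin_search_space` is Ax's test stub; its
import path is an assumption, not part of this diff):

    import warnings
    from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy
    from ax.utils.testing.core_stubs import get_branin_search_space

    # New spelling: Sobol is capped at 2; the BO step gets
    # min(2, DEFAULT_BAYESIAN_CONCURRENCY).
    gs = choose_generation_strategy_legacy(
        search_space=get_branin_search_space(), max_concurrency_cap=2
    )

    # Old spelling still works but warns and forwards to the new kwarg.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        choose_generation_strategy_legacy(
            search_space=get_branin_search_space(), max_parallelism_cap=2
        )
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)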
@@ -1012,7 +1013,7 @@ def __new__( generator_kwargs: dict[str, Any] | None = None, generator_gen_kwargs: dict[str, Any] | None = None, min_trials_observed: int = 0, - max_parallelism: int | None = None, + max_concurrency: int | None = None, enforce_num_trials: bool = True, should_deduplicate: bool = False, generator_name: str | None = None, @@ -1022,6 +1023,7 @@ def __new__( # Deprecated arguments for backwards compatibility. model_kwargs: dict[str, Any] | None = None, model_gen_kwargs: dict[str, Any] | None = None, + max_parallelism: int | None = None, # DEPRECATED: use max_concurrency. ) -> GenerationNode: r"""Creates a ``GenerationNode`` configured as a single-model generation step. @@ -1033,15 +1035,29 @@ def __new__( if use_update: raise DeprecationWarning("`GenerationStep.use_update` is deprecated.") + # Handle deprecated `max_parallelism` argument. + if max_parallelism is not None: + warnings.warn( + "`max_parallelism` is deprecated and will be removed in Ax 1.4. " + "Use `max_concurrency` instead.", + DeprecationWarning, + stacklevel=2, + ) + if max_concurrency is not None: + raise UserInputError( + "Cannot specify both `max_parallelism` and `max_concurrency`." + ) + max_concurrency = max_parallelism + if num_trials < 1 and num_trials != -1: raise UserInputError( "`num_trials` must be positive or -1 (indicating unlimited) " "for all generation steps." ) - if max_parallelism is not None and max_parallelism < 1: + if max_concurrency is not None and max_concurrency < 1: raise UserInputError( - "Maximum parallelism should be None (if no limit) or " - f"a positive number. Got: {max_parallelism} for " + "Maximum concurrency should be None (if no limit) or " + f"a positive number. Got: {max_concurrency} for " f"step {generator_name}." ) @@ -1115,10 +1131,10 @@ def __new__( use_all_trials_in_exp=use_all_trials_in_exp, ) ) - if max_parallelism is not None: + if max_concurrency is not None: transition_criteria.append( MaxGenerationParallelism( - threshold=max_parallelism, + threshold=max_concurrency, transition_to=placeholder_transition_to, only_in_statuses=[TrialStatus.RUNNING], block_gen_if_met=True, diff --git a/ax/generation_strategy/tests/test_dispatch_utils.py b/ax/generation_strategy/tests/test_dispatch_utils.py index c382dd9fe3a..1b9b0803fbf 100644 --- a/ax/generation_strategy/tests/test_dispatch_utils.py +++ b/ax/generation_strategy/tests/test_dispatch_utils.py @@ -19,7 +19,7 @@ _make_botorch_step, calculate_num_initialization_trials, choose_generation_strategy_legacy, - DEFAULT_BAYESIAN_PARALLELISM, + DEFAULT_BAYESIAN_CONCURRENCY, ) from ax.generation_strategy.generation_node import GenerationNode from ax.generation_strategy.transition_criterion import ( @@ -621,14 +621,14 @@ def test_enforce_sequential_optimization(self) -> None: sobol_gpei._nodes[0].transition_criteria[0], MinTrials ) self.assertTrue(node0_min_trials.block_gen_if_met) - # Check that max_parallelism is set by verifying MaxGenerationParallelism + # Check that max_concurrency is set by verifying MaxGenerationParallelism # criterion exists on node 1 - node1_max_parallelism = [ + node1_max_concurrency = [ tc for tc in sobol_gpei._nodes[1].transition_criteria if isinstance(tc, MaxGenerationParallelism) ] - self.assertTrue(len(node1_max_parallelism) > 0) + self.assertTrue(len(node1_max_concurrency) > 0) with self.subTest("False"): sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), @@ -646,22 +646,22 @@ def test_enforce_sequential_optimization(self) -> None: 
sobol_gpei._nodes[0].transition_criteria[0], MinTrials ) self.assertFalse(node0_min_trials.block_gen_if_met) - # Check that max_parallelism is None by verifying no + # Check that max_concurrency is None by verifying no # MaxGenerationParallelism criterion exists on node 1 - node1_max_parallelism = [ + node1_max_concurrency = [ tc for tc in sobol_gpei._nodes[1].transition_criteria if isinstance(tc, MaxGenerationParallelism) ] - self.assertEqual(len(node1_max_parallelism), 0) - with self.subTest("False and max_parallelism_override"): + self.assertEqual(len(node1_max_concurrency), 0) + with self.subTest("False and max_concurrency_override"): with self.assertLogs( choose_generation_strategy_legacy.__module__, logging.INFO ) as logger: choose_generation_strategy_legacy( search_space=get_branin_search_space(), enforce_sequential_optimization=False, - max_parallelism_override=5, + max_concurrency_override=5, ) self.assertTrue( any( @@ -670,14 +670,14 @@ def test_enforce_sequential_optimization(self) -> None: ), logger.output, ) - with self.subTest("False and max_parallelism_cap"): + with self.subTest("False and max_concurrency_cap"): with self.assertLogs( choose_generation_strategy_legacy.__module__, logging.INFO ) as logger: choose_generation_strategy_legacy( search_space=get_branin_search_space(), enforce_sequential_optimization=False, - max_parallelism_cap=5, + max_concurrency_cap=5, ) self.assertTrue( any( @@ -686,27 +686,27 @@ def test_enforce_sequential_optimization(self) -> None: ), logger.output, ) - with self.subTest("False and max_parallelism_override and max_parallelism_cap"): + with self.subTest("False and max_concurrency_override and max_concurrency_cap"): with self.assertRaisesRegex( ValueError, ( - "If `max_parallelism_override` specified, cannot also apply " - "`max_parallelism_cap`." + "If `max_concurrency_override` specified, cannot also apply " + "`max_concurrency_cap`." ), ): choose_generation_strategy_legacy( search_space=get_branin_search_space(), enforce_sequential_optimization=False, - max_parallelism_override=5, - max_parallelism_cap=5, + max_concurrency_override=5, + max_concurrency_cap=5, ) - def test_max_parallelism_override(self) -> None: + def test_max_concurrency_override(self) -> None: sobol_gpei = choose_generation_strategy_legacy( - search_space=get_branin_search_space(), max_parallelism_override=10 + search_space=get_branin_search_space(), max_concurrency_override=10 ) self.assertTrue( - all(self._get_max_parallelism(s) == 10 for s in sobol_gpei._nodes) + all(self._get_max_concurrency(s) == 10 for s in sobol_gpei._nodes) ) def test_winsorization(self) -> None: @@ -817,47 +817,47 @@ def test_fixed_num_initialization_trials(self) -> None: 3, ) - def _get_max_parallelism(self, node: GenerationNode) -> int | None: - """Helper to extract max_parallelism from transition criteria.""" + def _get_max_concurrency(self, node: GenerationNode) -> int | None: + """Helper to extract max_concurrency from transition criteria.""" for tc in node.transition_criteria: if isinstance(tc, MaxGenerationParallelism): return tc.threshold return None - def test_max_parallelism_adjustments(self) -> None: + def test_max_concurrency_adjustments(self) -> None: # No adjustment. 
sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space() ) - self.assertIsNone(self._get_max_parallelism(sobol_gpei._nodes[0])) + self.assertIsNone(self._get_max_concurrency(sobol_gpei._nodes[0])) self.assertEqual( - self._get_max_parallelism(sobol_gpei._nodes[1]), - DEFAULT_BAYESIAN_PARALLELISM, + self._get_max_concurrency(sobol_gpei._nodes[1]), + DEFAULT_BAYESIAN_CONCURRENCY, ) # Impose a cap of 1 on max parallelism for all steps. sobol_gpei = choose_generation_strategy_legacy( - search_space=get_branin_search_space(), max_parallelism_cap=1 + search_space=get_branin_search_space(), max_concurrency_cap=1 ) self.assertEqual( - self._get_max_parallelism(sobol_gpei._nodes[0]), + self._get_max_concurrency(sobol_gpei._nodes[0]), 1, ) self.assertEqual( - self._get_max_parallelism(sobol_gpei._nodes[1]), + self._get_max_concurrency(sobol_gpei._nodes[1]), 1, ) # Disable enforcing max parallelism for all steps. sobol_gpei = choose_generation_strategy_legacy( - search_space=get_branin_search_space(), max_parallelism_override=-1 + search_space=get_branin_search_space(), max_concurrency_override=-1 ) - self.assertIsNone(self._get_max_parallelism(sobol_gpei._nodes[0])) - self.assertIsNone(self._get_max_parallelism(sobol_gpei._nodes[1])) + self.assertIsNone(self._get_max_concurrency(sobol_gpei._nodes[0])) + self.assertIsNone(self._get_max_concurrency(sobol_gpei._nodes[1])) # Override max parallelism for all steps. sobol_gpei = choose_generation_strategy_legacy( - search_space=get_branin_search_space(), max_parallelism_override=10 + search_space=get_branin_search_space(), max_concurrency_override=10 ) - self.assertEqual(self._get_max_parallelism(sobol_gpei._nodes[0]), 10) - self.assertEqual(self._get_max_parallelism(sobol_gpei._nodes[1]), 10) + self.assertEqual(self._get_max_concurrency(sobol_gpei._nodes[0]), 10) + self.assertEqual(self._get_max_concurrency(sobol_gpei._nodes[1]), 10) def test_set_should_deduplicate(self) -> None: sobol_gpei = choose_generation_strategy_legacy( diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index 8d4804893b7..cc9f1cd639c 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -984,10 +984,10 @@ def should_consider_optimization_complete(self) -> tuple[bool, str]: """ if self._optimization_complete: return True, "" - if len(self.pending_trials) == 0 and self._get_max_pending_trials() == 0: + if len(self.pending_trials) == 0 and self._get_max_concurrent_trials() == 0: return ( True, - "All pending trials have completed and max_pending_trials is zero.", + "All pending trials have completed and max_concurrent_trials is zero.", ) should_stop, message = self._should_stop_due_to_global_stopping_strategy() @@ -1084,7 +1084,7 @@ def _check_if_failure_rate_exceeded(self, force_check: bool = False) -> bool: Effect on state: If the failure rate has been exceeded, a warning is logged and the private attribute `_failure_rate_has_been_exceeded` is set to True, which causes the - `_get_max_pending_trials` to return zero, so that no further trials are + `_get_max_concurrent_trials` to return zero, so that no further trials are scheduled and an error is raised at the end of the optimization. Returns: @@ -1120,7 +1120,7 @@ def _check_if_failure_rate_exceeded(self, force_check: bool = False) -> bool: "check if anything could cause your metrics to be flaky or " "broken." 
) - # NOTE: this private attribute causes `_get_max_pending_trials` to + # NOTE: this private attribute causes `_get_max_concurrent_trials` to # return zero, which causes no further trials to be scheduled. self._failure_rate_has_been_exceeded = True return True @@ -1639,14 +1639,14 @@ def _validate_options(self, options: OrchestratorOptions) -> None: "MultiTypeExperiment." ) - def _get_max_pending_trials(self) -> int: - """Returns the maximum number of pending trials specified in the options, or + def _get_max_concurrent_trials(self) -> int: + """Returns the maximum number of concurrent trials specified in the options, or zero, if the failure rate limit has been exceeded at any point during the optimization. """ if self._failure_rate_has_been_exceeded: return 0 - return self.options.max_pending_trials + return self.options.max_concurrent_trials def _prepare_trials( self, max_new_trials: int @@ -1679,14 +1679,15 @@ def _prepare_trials( # limit on pending trials and limit on total trials. n = capacity if self.options.run_trials_in_batches else 1 total_trials = self.options.total_trials - max_pending_trials = self._get_max_pending_trials() + max_concurrent_trials = self._get_max_concurrent_trials() num_pending_trials = len(self.pending_trials) - max_pending_upper_bound = max_pending_trials - num_pending_trials + max_pending_upper_bound = max_concurrent_trials - num_pending_trials if max_pending_upper_bound < 1: self.logger.debug( - f"`max_pending_trials={max_pending_trials}` and {num_pending_trials} " - "trials are currently pending; not initiating any additional trials." + f"`max_concurrent_trials={max_concurrent_trials}` and " + f"{num_pending_trials} trials are currently pending; " + "not initiating any additional trials." ) return [], [] n = max_pending_upper_bound if n == -1 else min(max_pending_upper_bound, n) diff --git a/ax/orchestration/orchestrator_options.py b/ax/orchestration/orchestrator_options.py index 70ab3b04cf3..e477c65cb54 100644 --- a/ax/orchestration/orchestrator_options.py +++ b/ax/orchestration/orchestrator_options.py @@ -5,12 +5,14 @@ # pyre-strict -from dataclasses import dataclass, field +import warnings +from dataclasses import dataclass, field, InitVar from enum import Enum from logging import INFO from typing import Any from ax.early_stopping.strategies import BaseEarlyStoppingStrategy +from ax.exceptions.core import UserInputError from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy @@ -24,10 +26,10 @@ class OrchestratorOptions: """Settings for a Orchestrator instance. Attributes: - max_pending_trials: Maximum number of pending trials the Orchestrator + max_concurrent_trials: Maximum number of concurrent trials the Orchestrator can have ``STAGED`` or ``RUNNING`` at once, required. If looking to use ``Runner.poll_available_capacity`` as a primary guide for - how many trials should be pending at a given time, set this limit + how many trials should be concurrent at a given time, set this limit to a high number, as an upper bound on number of trials that should not be exceeded. trial_type: Type of trials (1-arm ``Trial`` or multi-arm ``Batch @@ -90,7 +92,7 @@ class OrchestratorOptions: deployment. The size of the groups will be determined as the minimum of ``self.poll_available_capacity()`` and the number of generator runs that the generation strategy is able to produce - without more data or reaching its allowed max paralellism limit. + without more data or reaching its allowed max concurrency limit. 
debug_log_run_metadata: Whether to log run_metadata for debugging purposes. early_stopping_strategy: A ``BaseEarlyStoppingStrategy`` that determines whether a trial should be stopped given the current state of @@ -125,7 +127,7 @@ class OrchestratorOptions: Default to False. """ - max_pending_trials: int = 10 + max_concurrent_trials: int = 10 trial_type: TrialType = TrialType.TRIAL batch_size: int | None = None total_trials: int | None = None @@ -149,7 +151,24 @@ class OrchestratorOptions: enforce_immutable_search_space_and_opt_config: bool = True mt_experiment_trial_type: str | None = None terminate_if_status_quo_infeasible: bool = False + # Deprecated argument for backwards compatibility. + max_pending_trials: InitVar[int | None] = None + + def __post_init__(self, max_pending_trials: int | None) -> None: + # Handle deprecated `max_pending_trials` argument. + if max_pending_trials is not None: + warnings.warn( + "`max_pending_trials` is deprecated and will be removed in Ax 1.4. " + "Use `max_concurrent_trials` instead.", + DeprecationWarning, + stacklevel=2, + ) + if self.max_concurrent_trials != 10: + raise UserInputError( + "Cannot specify both `max_pending_trials` and " + "`max_concurrent_trials`." + ) + object.__setattr__(self, "max_concurrent_trials", max_pending_trials) - def __post_init__(self) -> None: if self.early_stopping_strategy is not None: object.__setattr__(self, "seconds_between_polls_backoff_factor", 1) diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 67f982ca1a1..026c19de0d3 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -165,7 +165,7 @@ class TestAxOrchestrator(TestCase): "generator_key_override=None)], " "transition_criteria=[MaxGenerationParallelism(" "transition_to='GenerationStep_1_BoTorch')])]), " - "options=OrchestratorOptions(max_pending_trials=10, " + "options=OrchestratorOptions(max_concurrent_trials=10, " "trial_type=, batch_size=None, " "total_trials=0, tolerated_trial_failure_rate=0.2, " "min_failed_trials_for_failure_rate_check=5, log_filepath=None, " @@ -1425,7 +1425,7 @@ def test_optimization_complete(self) -> None: experiment=self.branin_experiment, # Has runner and metrics. generation_strategy=gs, options=OrchestratorOptions( - max_pending_trials=100, + max_concurrent_trials=100, init_seconds_between_polls=0, # Short between polls so test is fast. **self.orchestrator_options_kwargs, ), @@ -1459,7 +1459,7 @@ def test_suppress_all_storage_errors(self, mock_save_exp: Mock, _) -> None: experiment=self.branin_experiment, # Has runner and metrics. generation_strategy=gs, options=OrchestratorOptions( - max_pending_trials=100, + max_concurrent_trials=100, init_seconds_between_polls=0, # Short between polls so test is fast. suppress_storage_errors_after_retries=True, **self.orchestrator_options_kwargs, @@ -1468,14 +1468,14 @@ def test_suppress_all_storage_errors(self, mock_save_exp: Mock, _) -> None: ) self.assertEqual(mock_save_exp.call_count, 3) - def test_max_pending_trials(self) -> None: + def test_max_concurrent_trials(self) -> None: # With runners & metrics, `BareBonesTestOrchestrator.run_all_trials` should run. gs = self.sobol_MBM_GS orchestrator = MockOrchestrator( experiment=self.branin_experiment, # Has runner and metrics. generation_strategy=gs, options=OrchestratorOptions( - max_pending_trials=1, + max_concurrent_trials=1, init_seconds_between_polls=0, # Short between polls so test is fast. 
                **self.orchestrator_options_kwargs,
            ),
        )
@@ -2858,7 +2858,7 @@ class TestAxOrchestratorMultiTypeExperiment(TestAxOrchestrator):
        "generator_key_override=None)], "
        "transition_criteria="
        "[MaxGenerationParallelism(transition_to='GenerationStep_1_BoTorch')])]), "
-        "options=OrchestratorOptions(max_pending_trials=10, "
+        "options=OrchestratorOptions(max_concurrent_trials=10, "
        "trial_type=, batch_size=None, "
        "total_trials=0, tolerated_trial_failure_rate=0.2, "
        "min_failed_trials_for_failure_rate_check=5, log_filepath=None, "
diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py
index 47a8f54ad36..49046939ca3 100644
--- a/ax/service/ax_client.py
+++ b/ax/service/ax_client.py
@@ -837,39 +837,39 @@ def get_trials_data_frame(self) -> pd.DataFrame:
        """
        return self.experiment.to_df()

-    def get_max_parallelism(self) -> list[tuple[int, int]]:
-        """Retrieves maximum number of trials that can be scheduled in parallel
+    def get_max_concurrency(self) -> list[tuple[int, int]]:
+        """Retrieves maximum number of trials that can be scheduled concurrently
        at different stages of optimization.

        Some optimization algorithms profit significantly from sequential
        optimization (i.e. suggest a few points, get updated with data for
        them, repeat, see https://ax.dev/docs/bayesopt.html).
-        Parallelism setting indicates how many trials should be running simulteneously
+        Concurrency setting indicates how many trials should be running simultaneously
        (generated, but not yet completed with data).

        The output of this method is mapping of form
-        {num_trials -> max_parallelism_setting}, where the max_parallelism_setting
-        is used for num_trials trials. If max_parallelism_setting is -1, as
-        many of the trials can be ran in parallel, as necessary. If num_trials
-        in a tuple is -1, then the corresponding max_parallelism_setting
+        {num_trials -> max_concurrency_setting}, where the max_concurrency_setting
+        is used for num_trials trials. If max_concurrency_setting is -1, as
+        many of the trials can be run concurrently, as necessary. If num_trials
+        in a tuple is -1, then the corresponding max_concurrency_setting
        should be used for all subsequent trials.

        For example, if the returned list is [(5, -1), (12, 6), (-1, 3)],
-        the schedule could be: run 5 trials with any parallelism, run 6 trials in
-        parallel twice, run 3 trials in parallel for as long as needed. Here,
+        the schedule could be: run 5 trials with any concurrency, run 6 trials
+        concurrently twice, run 3 trials concurrently for as long as needed. Here,
        'running' a trial means obtaining a next trial from `AxClient` through
        get_next_trials and completing it with data when available.

        Returns:
-            Mapping of form {num_trials -> max_parallelism_setting}.
+            Mapping of form {num_trials -> max_concurrency_setting}.
        """
-        parallelism_settings = []
+        concurrency_settings = []
        for node in self.generation_strategy._nodes:
-            # Extract max_parallelism from MaxGenerationParallelism criterion
-            max_parallelism = None
+            # Extract max_concurrency from MaxGenerationParallelism criterion
+            max_concurrency = None
            for tc in node.transition_criteria:
                if isinstance(tc, MaxGenerationParallelism):
-                    max_parallelism = tc.threshold
+                    max_concurrency = tc.threshold
                    break
            # Try to get num_trials from the node. If there's no MinTrials
            # criterion (unlimited trials), num_trials will raise UserInputError.
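NOTE: How the renamed accessor reads after this diff; the two-parameter setup
and the expected value both come from test_recommended_parallelism below:

    from ax.service.ax_client import AxClient

    ax_client = AxClient()
    ax_client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
    )
    # [(5, 5), (-1, 3)]: up to 5 Sobol trials may run at once, then at most
    # 3 Bayesian optimization trials at a time for the rest of the run.
    assert ax_client.get_max_concurrency() == [(5, 5), (-1, 3)]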
@@ -878,8 +878,18 @@ def get_max_parallelism(self) -> list[tuple[int, int]]: num_trials = node.num_trials except UserInputError: num_trials = -1 - parallelism_settings.append((num_trials, max_parallelism or num_trials)) - return parallelism_settings + concurrency_settings.append((num_trials, max_concurrency or num_trials)) + return concurrency_settings + + def get_max_parallelism(self) -> list[tuple[int, int]]: + """Deprecated. Use `get_max_concurrency` instead.""" + warnings.warn( + "`get_max_parallelism` is deprecated and will be removed in Ax 1.4. " + "Use `get_max_concurrency` instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.get_max_concurrency() def get_optimization_trace( self, objective_optimum: float | None = None diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index afe81f05d01..389e54d5a39 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -50,7 +50,7 @@ UserInputError, ) from ax.exceptions.generation_strategy import MaxParallelismReachedException -from ax.generation_strategy.dispatch_utils import DEFAULT_BAYESIAN_PARALLELISM +from ax.generation_strategy.dispatch_utils import DEFAULT_BAYESIAN_CONCURRENCY from ax.generation_strategy.generation_strategy import ( GenerationNode, GenerationStep, @@ -511,7 +511,7 @@ def test_default_generation_strategy_continuous(self) -> None: if i < 5: self.assertEqual(gen_limit, 5 - i) else: - self.assertEqual(gen_limit, DEFAULT_BAYESIAN_PARALLELISM) + self.assertEqual(gen_limit, DEFAULT_BAYESIAN_CONCURRENCY) parameterization, trial_index = ax_client.get_next_trial() x, y = parameterization.get("x"), parameterization.get("y") ax_client.complete_trial( @@ -1616,14 +1616,14 @@ def test_keep_generating_without_data(self) -> None: self.assertTrue(len(node0_min_trials) > 0) self.assertFalse(node0_min_trials[0].block_gen_if_met) - # Check that max_parallelism is None by verifying no MaxGenerationParallelism + # Check that max_concurrency is None by verifying no MaxGenerationParallelism # criterion exists on node 1 - node1_max_parallelism = [ + node1_max_concurrency = [ tc for tc in ax_client.generation_strategy._nodes[1].transition_criteria if isinstance(tc, MaxGenerationParallelism) ] - self.assertEqual(len(node1_max_parallelism), 0) + self.assertEqual(len(node1_max_concurrency), 0) for _ in range(10): ax_client.get_next_trial() @@ -1939,17 +1939,17 @@ def test_relative_oc_without_sq(self) -> None: def test_recommended_parallelism(self) -> None: ax_client = AxClient() with self.assertRaisesRegex(AssertionError, "No generation strategy"): - ax_client.get_max_parallelism() + ax_client.get_max_concurrency() ax_client.create_experiment( parameters=[ {"name": "x", "type": "range", "bounds": [-5.0, 10.0]}, {"name": "y", "type": "range", "bounds": [0.0, 15.0]}, ], ) - self.assertEqual(ax_client.get_max_parallelism(), [(5, 5), (-1, 3)]) + self.assertEqual(ax_client.get_max_concurrency(), [(5, 5), (-1, 3)]) self.assertEqual( run_trials_using_recommended_parallelism( - ax_client, ax_client.get_max_parallelism(), 20 + ax_client, ax_client.get_max_concurrency(), 20 ), 0, ) @@ -2319,7 +2319,7 @@ def test_deprecated_save_load_method_errors(self) -> None: with self.assertRaises(NotImplementedError): ax_client.load_experiment("test_experiment") with self.assertRaises(NotImplementedError): - ax_client.get_recommended_max_parallelism() + ax_client.get_recommended_max_concurrency() def test_find_last_trial_with_parameterization(self) -> None: ax_client = AxClient() @@ -2872,7 
+2872,7 @@ def test_estimate_early_stopping_savings(self) -> None: self.assertEqual(ax_client.estimate_early_stopping_savings(), 0) - def test_max_parallelism_exception_when_early_stopping(self) -> None: + def test_max_concurrency_exception_when_early_stopping(self) -> None: ax_client = AxClient() ax_client.create_experiment( parameters=[ diff --git a/ax/utils/common/complexity_utils.py b/ax/utils/common/complexity_utils.py index a0edc391df2..6a3ca5563c2 100644 --- a/ax/utils/common/complexity_utils.py +++ b/ax/utils/common/complexity_utils.py @@ -111,7 +111,7 @@ class OptimizationSummary: is True). tolerated_trial_failure_rate: Maximum tolerated trial failure rate (should be <= 0.9). - max_pending_trials: Maximum number of pending trials. + max_concurrent_trials: Maximum number of concurrent trials. min_failed_trials_for_failure_rate_check: Minimum failed trials before failure rate is checked. non_default_advanced_options: Whether non-default advanced options are set. @@ -133,7 +133,7 @@ class OptimizationSummary: # Optional keys max_trials: int | None = None tolerated_trial_failure_rate: float | None = None - max_pending_trials: int | None = None + max_concurrent_trials: int | None = None min_failed_trials_for_failure_rate_check: int | None = None non_default_advanced_options: bool | None = None uses_merge_multiple_curves: bool | None = None @@ -211,7 +211,7 @@ def summarize_ax_optimization_complexity( uses_merge_multiple_curves=uses_merge_multiple_curves, uses_standard_api=uses_standard_api, tolerated_trial_failure_rate=options.tolerated_trial_failure_rate, - max_pending_trials=options.max_pending_trials, + max_concurrent_trials=options.max_concurrent_trials, min_failed_trials_for_failure_rate_check=( options.min_failed_trials_for_failure_rate_check ), @@ -400,19 +400,19 @@ def _check_if_is_in_standard_other_settings( is_in_standard, is_supported = False, False why_not_supported.append(f"{tolerated_trial_failure_rate=} is larger than 0.9.") - max_pending_trials = optimization_summary.max_pending_trials + max_concurrent_trials = optimization_summary.max_concurrent_trials min_failed_trials_for_failure_rate_check = ( optimization_summary.min_failed_trials_for_failure_rate_check ) if ( - max_pending_trials is not None + max_concurrent_trials is not None and min_failed_trials_for_failure_rate_check is not None - and max(2 * max_pending_trials, 5) < min_failed_trials_for_failure_rate_check + and max(2 * max_concurrent_trials, 5) < min_failed_trials_for_failure_rate_check ): is_in_standard, is_supported = False, False why_not_supported.append( f"{min_failed_trials_for_failure_rate_check=} exceeds " - f"{max(2 * max_pending_trials, 5)=}. Please reduce " + f"{max(2 * max_concurrent_trials, 5)=}. Please reduce " "min_failed_trials_for_failure_rate_check below the stated threshold for " "this sweep to be in a supported tier." 
) @@ -457,7 +457,7 @@ def check_if_in_standard( num_categorical_6_inf, num_parameter_constraints - Optimization config: num_objectives, num_outcome_constraints - Other settings: max_trials, uses_early_stopping, uses_global_stopping, - uses_standard_api, tolerated_trial_failure_rate, max_pending_trials, + uses_standard_api, tolerated_trial_failure_rate, max_concurrent_trials, min_failed_trials_for_failure_rate_check, non_default_advanced_options, uses_merge_multiple_curves tier_messages: A ``TierMessages`` instance containing tier-specific diff --git a/ax/utils/common/tests/test_complexity_utils.py b/ax/utils/common/tests/test_complexity_utils.py index 4197424a67c..7cf22f7ad28 100644 --- a/ax/utils/common/tests/test_complexity_utils.py +++ b/ax/utils/common/tests/test_complexity_utils.py @@ -118,7 +118,7 @@ def test_orchestrator_options_extraction(self) -> None: # GIVEN custom orchestrator options options = OrchestratorOptions( tolerated_trial_failure_rate=0.25, - max_pending_trials=5, + max_concurrent_trials=5, min_failed_trials_for_failure_rate_check=10, ) @@ -131,7 +131,7 @@ def test_orchestrator_options_extraction(self) -> None: # THEN the summary should reflect orchestrator options self.assertEqual(summary.tolerated_trial_failure_rate, 0.25) - self.assertEqual(summary.max_pending_trials, 5) + self.assertEqual(summary.max_concurrent_trials, 5) self.assertEqual(summary.min_failed_trials_for_failure_rate_check, 10) def test_parameter_constraints_counted(self) -> None: @@ -254,7 +254,7 @@ def get_optimization_summary( uses_global_stopping: bool = False, uses_standard_api: bool = True, tolerated_trial_failure_rate: float | None = 0.5, - max_pending_trials: int | None = 5, + max_concurrent_trials: int | None = 5, min_failed_trials_for_failure_rate_check: int | None = 5, non_default_advanced_options: bool | None = None, uses_merge_multiple_curves: bool | None = None, @@ -273,7 +273,7 @@ def get_optimization_summary( uses_global_stopping=uses_global_stopping, uses_standard_api=uses_standard_api, tolerated_trial_failure_rate=tolerated_trial_failure_rate, - max_pending_trials=max_pending_trials, + max_concurrent_trials=max_concurrent_trials, min_failed_trials_for_failure_rate_check=( min_failed_trials_for_failure_rate_check ), @@ -383,7 +383,7 @@ def test_unsupported_tier_conditions(self) -> None: ), ( get_optimization_summary( - max_pending_trials=3, min_failed_trials_for_failure_rate_check=7 + max_concurrent_trials=3, min_failed_trials_for_failure_rate_check=7 ), "min_failed_trials_for_failure_rate_check=7", ),
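NOTE: End-to-end contract of the back-compat shims added in this diff, as a
sketch (the warning and UserInputError behavior mirror the implementation in
orchestrator_options.py above):

    import warnings
    from ax.exceptions.core import UserInputError
    from ax.orchestration.orchestrator_options import OrchestratorOptions

    # The old keyword still works, emits a DeprecationWarning, and forwards
    # its value to the new field.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        opts = OrchestratorOptions(max_pending_trials=5)
    assert opts.max_concurrent_trials == 5
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # Supplying both spellings at once is rejected.
    try:
        OrchestratorOptions(max_pending_trials=5, max_concurrent_trials=7)
    except UserInputError:
        pass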