Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion ax/analysis/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,22 @@
import numpy as np
import pandas as pd
from ax.analysis.plotly.utils import STALE_FAIL_REASON, truncate_label
from ax.analysis.utils import _relativize_df_with_sq, prepare_arm_data
from ax.analysis.utils import (
_get_scalarized_constraint_mean_and_sem,
_prepare_p_feasible,
_relativize_df_with_sq,
prepare_arm_data,
)
from ax.api.client import Client
from ax.api.configs import RangeParameterConfig
from ax.core.arm import Arm
from ax.core.batch_trial import BatchTrial
from ax.core.data import relativize_dataframe
from ax.core.experiment import Experiment
from ax.core.metric import Metric
from ax.core.outcome_constraint import OutcomeConstraint, ScalarizedOutcomeConstraint
from ax.core.trial_status import TrialStatus # noqa
from ax.core.types import ComparisonOp
from ax.exceptions.core import UserInputError
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_offline_experiments, get_online_experiments
Expand Down Expand Up @@ -865,3 +872,58 @@ def test_offline(self) -> None:
trial_index=trial_index,
additional_arms=additional_arms,
)

def test_scalarized_constraints(self) -> None:
    """Scalarized constraints: combined mean/SEM math, missing-metric
    fallback, and p_feasible over a mix of constraint types."""
    data = {
        "trial_index": [0, 0],
        "arm_name": ["arm1", "arm2"],
        "m1_mean": [5.0, 15.0],
        "m1_sem": [1.0, 1.0],
        "m2_mean": [5.0, 15.0],
        "m2_sem": [1.0, 1.0],
        "regular_mean": [8.0, 12.0],
        "regular_sem": [0.5, 0.5],
    }
    df = pd.DataFrame(data)

    scalarized_constraint = ScalarizedOutcomeConstraint(
        metrics=[Metric(name="m1"), Metric(name="m2")],
        weights=[1.0, 1.0],
        op=ComparisonOp.LEQ,
        bound=25.0,
        relative=False,
    )

    # Expected math: mean = sum(w_i * m_i), SEM = sqrt(sum((w_i * s_i)^2)).
    mean, sem = _get_scalarized_constraint_mean_and_sem(df, scalarized_constraint)
    np.testing.assert_array_almost_equal(mean, [10.0, 30.0])
    np.testing.assert_array_almost_equal(sem, [np.sqrt(2), np.sqrt(2)])

    # A missing component metric yields NaN mean and zero SEM.
    missing_constraint = ScalarizedOutcomeConstraint(
        metrics=[Metric(name="m1"), Metric(name="missing")],
        weights=[1.0, 1.0],
        op=ComparisonOp.LEQ,
        bound=10.0,
    )
    mean, sem = _get_scalarized_constraint_mean_and_sem(df, missing_constraint)
    self.assertTrue(np.all(np.isnan(mean)))
    np.testing.assert_array_equal(sem, np.zeros(2))

    # p_feasible should handle regular and scalarized constraints together.
    regular_constraint = OutcomeConstraint(
        metric=Metric(name="regular"),
        op=ComparisonOp.LEQ,
        bound=10.0,
        relative=False,
    )
    p_feasible = _prepare_p_feasible(
        df=df,
        status_quo_df=None,
        outcome_constraints=[regular_constraint, scalarized_constraint],
    )
    self.assertFalse(p_feasible.isna().any())
    # arm1 (regular=8, scalarized=10) should score as more feasible than
    # arm2 (regular=12, scalarized=30).
    self.assertGreater(p_feasible.iloc[0], p_feasible.iloc[1])
123 changes: 82 additions & 41 deletions ax/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Sequence

import numpy as np
import numpy.typing as npt
import pandas as pd
import torch
from ax.adapter.base import Adapter
Expand Down Expand Up @@ -540,6 +541,54 @@ def _extract_generation_node_name(trial: BaseTrial, arm: Arm) -> str:
return Keys.UNKNOWN_GENERATION_NODE.value


def _get_scalarized_constraint_mean_and_sem(
    df: pd.DataFrame,
    constraint: ScalarizedOutcomeConstraint,
) -> tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]]:
    """
    Compute the combined mean and SEM for a ScalarizedOutcomeConstraint.

    Assuming the component metrics are independent random variables:
        combined_mean = sum(weight_i * mean_i)
        combined_sem = sqrt(sum((weight_i * sem_i)^2))

    Args:
        df: DataFrame with "{metric_name}_mean" and "{metric_name}_sem" columns.
        constraint: The ScalarizedOutcomeConstraint.

    Returns:
        Tuple of (combined_mean, combined_sem) as numpy arrays.
        If any component metric is missing, mean is NaN and sem is 0.
    """
    n_rows = len(df)
    # Materialize once so the pairs can be scanned twice below.
    metric_weights = list(constraint.metric_weights)

    # If any component metric lacks a mean column the scalarization is
    # undefined; match the existing convention of NaN mean and zero SEM.
    if any(f"{metric.name}_mean" not in df.columns for metric, _ in metric_weights):
        return np.full(n_rows, np.nan), np.zeros(n_rows)

    combined_mean = np.zeros(n_rows)
    combined_var = np.zeros(n_rows)
    for metric, weight in metric_weights:
        combined_mean += weight * df[f"{metric.name}_mean"].values

        sem_col = f"{metric.name}_sem"
        if sem_col in df.columns:
            # Treat a missing (NaN) SEM entry as zero uncertainty.
            metric_sem = df[sem_col].fillna(0).values
        else:
            metric_sem = np.zeros(n_rows)
        combined_var += (weight**2) * (metric_sem**2)

    return combined_mean, np.sqrt(combined_var)


def _prepare_p_feasible(
df: pd.DataFrame,
status_quo_df: pd.DataFrame | None,
Expand Down Expand Up @@ -571,34 +620,27 @@ def _prepare_p_feasible(
return pd.Series(np.ones(len(df)))

# If an arm is missing data for a metric leave the mean as NaN.
oc_names = []
for oc in outcome_constraints:
if isinstance(oc, ScalarizedOutcomeConstraint):
# take the str representation of the scalarized outcome constraint
oc_names.append(str(oc))
else:
oc_names.append(oc.metric.name)

assert len(oc_names) == len(outcome_constraints)

means = []
sigmas = []
for i, oc_name in enumerate(oc_names):
df_constraint = none_throws(rel_df if outcome_constraints[i].relative else df)
# TODO[T235432214]: currently we are leaving the mean as NaN if the constraint
# is on ScalarizedOutcomeConstraint but we should be able to calculate it by
# setting the mean to be weights * individual metrics and sem to be
# sqrt(sum((weights * individual_sems)^2)), assuming independence.
if f"{oc_name}_mean" in df_constraint.columns:
means.append(df_constraint[f"{oc_name}_mean"].tolist())
for oc in outcome_constraints:
df_constraint = none_throws(rel_df if oc.relative else df)

if isinstance(oc, ScalarizedOutcomeConstraint):
mean, sem = _get_scalarized_constraint_mean_and_sem(df_constraint, oc)
means.append(mean.tolist())
sigmas.append(sem.tolist())
else:
means.append([float("nan")] * len(df_constraint))
sigmas.append(
(df_constraint[f"{oc_name}_sem"].fillna(0)).tolist()
if f"{oc_name}_sem" in df_constraint.columns
else [0] * len(df)
)
metric_name = oc.metric.name
if f"{metric_name}_mean" in df_constraint.columns:
means.append(df_constraint[f"{metric_name}_mean"].tolist())
else:
means.append([float("nan")] * len(df_constraint))

sigmas.append(
(df_constraint[f"{metric_name}_sem"].fillna(0)).tolist()
if f"{metric_name}_sem" in df_constraint.columns
else [0] * len(df)
)

con_lower_inds = [
i
Expand Down Expand Up @@ -665,28 +707,27 @@ def _prepare_p_feasible_per_constraint(
if len(outcome_constraints) == 0:
return pd.DataFrame(index=df.index)

oc_names = []
for oc in outcome_constraints:
if isinstance(oc, ScalarizedOutcomeConstraint):
oc_names.append(str(oc))
else:
oc_names.append(oc.metric.name)

result_df = pd.DataFrame(index=df.index)
# Compute probability for each constraint individually
for oc_name, oc in zip(oc_names, outcome_constraints):
for oc in outcome_constraints:
df_constraint = none_throws(rel_df if oc.relative else df)

# Get mean and sigma for this constraint
if f"{oc_name}_mean" in df_constraint.columns:
mean = df_constraint[f"{oc_name}_mean"].values
if isinstance(oc, ScalarizedOutcomeConstraint):
mean, sigma = _get_scalarized_constraint_mean_and_sem(df_constraint, oc)
oc_display_name = str(oc)
else:
mean = np.nan * np.ones(len(df_constraint))
metric_name = oc.metric.name
oc_display_name = metric_name

if f"{oc_name}_sem" in df_constraint.columns:
sigma = df_constraint[f"{oc_name}_sem"].fillna(0).values
else:
sigma = np.zeros(len(df))
if f"{metric_name}_mean" in df_constraint.columns:
mean = df_constraint[f"{metric_name}_mean"].values
else:
mean = np.full(len(df_constraint), np.nan)

if f"{metric_name}_sem" in df_constraint.columns:
sigma = df_constraint[f"{metric_name}_sem"].fillna(0).values
else:
sigma = np.zeros(len(df))

# Convert to torch tensors (shape: [n_arms, 1])
mean_tensor = torch.tensor(mean, dtype=torch.double).unsqueeze(-1)
Expand All @@ -706,7 +747,7 @@ def _prepare_p_feasible_per_constraint(

# Convert back to numpy and store in result dataframe
prob = log_prob.exp().squeeze().numpy()
result_df[f"p_feasible_{oc_name}"] = prob
result_df[f"p_feasible_{oc_display_name}"] = prob

return result_df

Expand Down