Conversation
* Define mean functions for the available presets
* Pass both mean and likelihood when instantiating preset
Pull request overview
Adds an `IndexKernel` and expands `GaussianProcessSurrogate` customization via kernel/mean/likelihood factories, including new GP presets and serialization behavior for non-BayBE GP components.
Changes:
- Introduces `IndexKernel` and adds it to Hypothesis kernel strategies.
- Refactors the GP surrogate to accept kernel/mean/likelihood components (including some GPyTorch objects), adds `EDBO`/`EDBO_SMOOTHED` presets, and updates preset exports.
- Updates serialization hooks and generic/abstract type handling; adds tests for transfer learning and GPyTorch-kernel behavior.
Reviewed changes
Copilot reviewed 23 out of 23 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| tests/test_transfer_learning.py | Adds transfer-learning recommendation test coverage |
| tests/test_gp.py | Adds tests for using GPyTorch kernels and blocking their serialization |
| tests/hypothesis_strategies/kernels.py | Extends kernel generation strategies to include IndexKernel |
| baybe/utils/boolean.py | Updates abstract detection to handle generics |
| baybe/utils/basic.py | Updates subclass discovery to handle generics |
| baybe/surrogates/gaussian_process/presets/utils.py | Adds lazy-loaded constant mean factory |
| baybe/surrogates/gaussian_process/presets/edbo_smoothed.py | Adds “smoothed EDBO” kernel/likelihood preset implementations |
| baybe/surrogates/gaussian_process/presets/edbo.py | Refactors EDBO preset into kernel/mean/likelihood factories |
| baybe/surrogates/gaussian_process/presets/default.py | Repoints default preset to smoothed EDBO kernel/likelihood + lazy mean |
| baybe/surrogates/gaussian_process/presets/core.py | Adds preset enum values and preset-based GP construction |
| baybe/surrogates/gaussian_process/presets/__init__.py | Expands exported preset factories and core API |
| baybe/surrogates/gaussian_process/kernel_factory.py | Removes old kernel factory module (replaced by components) |
| baybe/surrogates/gaussian_process/core.py | Adds mean/likelihood factories; uses component factories and IndexKernel for multitask |
| baybe/surrogates/gaussian_process/components/mean.py | Introduces GP mean component factory typing |
| baybe/surrogates/gaussian_process/components/likelihood.py | Introduces GP likelihood component factory typing |
| baybe/surrogates/gaussian_process/components/kernel.py | Introduces GP kernel component factory typing |
| baybe/surrogates/gaussian_process/components/generic.py | Adds generic component factory + serialization blocking hook for GPyTorch kernels |
| baybe/surrogates/gaussian_process/components/__init__.py | Exposes component factory types |
| baybe/serialization/core.py | Improves base-class (un)structuring and non-BayBE type names in error messages |
| baybe/searchspace/core.py | Refactors task-parameter lookup into a helper property |
| baybe/kernels/basic.py | Adds IndexKernel |
| baybe/kernels/base.py | Updates kernel factory import path and broadens active_dims type |
| CHANGELOG.md | Documents new GP customization support, presets, and IndexKernel |
```python
@pytest.mark.parametrize("active_tasks", ["target_only", "both"])
@pytest.mark.parametrize("training_data", ["source", "target", "both"])
def test_recommendation(campaign: Campaign):
    """Transfer learning recommendation works regardless of which tasks are
    present in the training data and which tasks are active.
    """  # noqa: D205
    campaign.recommend(1)
```
This test is parametrized with active_tasks and training_data, but the test function does not accept these arguments and there are no fixtures with those names. Pytest will error at collection time (unused parametrization arguments). Consider either (a) adding active_tasks and training_data as parametrized fixtures that campaign depends on, or (b) parametrizing campaign indirectly (e.g., @pytest.mark.parametrize('campaign', ..., indirect=True)) and removing the unused parametrizations.
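A minimal sketch of option (b), using indirect parametrization; the `make_campaign` helper and fixture wiring are illustrative assumptions, not the repository's actual API:

```python
import pytest

@pytest.fixture
def campaign(request):
    # Build the campaign from the indirectly supplied parameters
    # (hypothetical helper; the real fixture logic may differ).
    active_tasks, training_data = request.param
    return make_campaign(active_tasks=active_tasks, training_data=training_data)

@pytest.mark.parametrize(
    "campaign",
    [
        (active, data)
        for active in ("target_only", "both")
        for data in ("source", "target", "both")
    ],
    indirect=True,
)
def test_recommendation(campaign):
    """Recommendation works for all task/training-data combinations."""
    campaign.recommend(1)
```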
```python
def _is_gpytorch_kernel_class(obj) -> bool:
    """Check if a class is a GPyTorch kernel class using lazy loading."""
    if sys.modules.get("gpytorch") is None:
        return False
    from gpytorch.kernels import Kernel as GPyTorchKernel

    return issubclass(obj, GPyTorchKernel)


def _validate_component(instance, attribute: Attribute, value: Any):
    """Validate that an object is a BayBE or a GPyTorch GP component."""
    if isinstance(value, Kernel) or _is_gpytorch_kernel_class(type(value)):
        return

    raise TypeError(
        f"The object provided for '{attribute.alias}' of "
        f"'{instance.__class__.__name__}' must be a BayBE or a GPyTorch GP component. "
        f"Got: {type(value)}"
    )
```
The PR claims support for GPyTorch means and likelihoods, but the component checks/validators only recognize BayBE `Kernel` or GPyTorch kernels. Passing a `gpytorch.means.Mean` or `gpytorch.likelihoods.Likelihood` instance will currently fall through `to_component_factory` and be returned as-is (not wrapped), which will later fail because it is not a factory callable. Update the component detection/validation/conversion logic to also recognize and wrap GPyTorch mean and likelihood instances (and register the corresponding serialization-block hooks as needed).
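A minimal sketch of how the lazy check could be generalized to all three component types; the function name is illustrative:

```python
import sys

def _is_gpytorch_component_class(obj) -> bool:
    """Check if a class is a GPyTorch kernel, mean, or likelihood class."""
    # Lazy check: avoid importing gpytorch unless it is already loaded.
    if sys.modules.get("gpytorch") is None:
        return False
    from gpytorch.kernels import Kernel as GPyTorchKernel
    from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood
    from gpytorch.means import Mean as GPyTorchMean

    return isinstance(obj, type) and issubclass(
        obj, (GPyTorchKernel, GPyTorchMean, GPyTorchLikelihood)
    )
```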
```python
def to_component_factory(x: Component | ComponentFactory, /) -> ComponentFactory:
    """Wrap a component into a plain component factory (with factory passthrough)."""
    if isinstance(x, Component) or _is_gpytorch_kernel_class(type(x)):
        return PlainComponentFactory(x)
    return x
```
(Duplicate of the comment above: the same detection/conversion gap applies to `to_component_factory`.)
```python
) -> GPyTorchMean:
    from gpytorch.means import ConstantMean

    return ConstantMean()
```
Previously, the GP code set `batch_shape = train_x.shape[:-2]` and constructed `ConstantMean(batch_shape=batch_shape)` to support batched training data. Returning `ConstantMean()` without `batch_shape` can break or mis-broadcast in batched settings. Consider deriving `batch_shape` from `train_x` here and passing it into `ConstantMean`.
```diff
-    return ConstantMean()
+    batch_shape = train_x.shape[:-2]
+    return ConstantMean(batch_shape=batch_shape)
```
```python
import torch
from gpytorch.likelihoods import GaussianLikelihood
```
The previous implementation created GaussianLikelihood(..., batch_shape=batch_shape) and set noise with the correct batch shape/device/dtype. Here, GaussianLikelihood is created without batch_shape, and torch.tensor([initial_value]) will default to CPU and default dtype. This can cause device/dtype mismatches or incorrect broadcasting in batched contexts. Consider constructing the noise tensor on train_x.device/dtype and passing compatible batch_shape when applicable.
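A minimal sketch of a batch-, device-, and dtype-aware construction; `train_x` and `initial_value` are assumed from the surrounding factory context:

```python
batch_shape = train_x.shape[:-2]
likelihood = GaussianLikelihood(batch_shape=batch_shape)
# Initialize the noise on the same device/dtype as the training data.
likelihood.noise = torch.full(
    (*batch_shape, 1), initial_value, device=train_x.device, dtype=train_x.dtype
)
```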
```python
        rank=context.n_tasks,  # TODO: make controllable
    )
    covar_module = base_covar_module * task_covar_module
    ).to_gpytorch(active_dims=context.task_idx)
```
`Kernel.to_gpytorch(..., active_dims=...)` now types `active_dims` as `Sequence[int] | None`, but `context.task_idx` is an `int`. If the underlying GPyTorch kernel expects an iterable of indices, passing an `int` can break. Consider passing a 1-element sequence instead (e.g., `(context.task_idx,)`).
```diff
-).to_gpytorch(active_dims=context.task_idx)
+).to_gpytorch(active_dims=(context.task_idx,))
```
baybe/kernels/basic.py (outdated)
```python
num_tasks: int = field(default=None, validator=[instance_of(int), ge(2)])
"""The number of tasks."""

rank: int = field(default=None, validator=[instance_of(int), ge(1)])
"""The rank of the task covariance matrix."""
```
num_tasks and rank are typed as int but default to None, while the validators reject None. This creates an invalid default state and can lead to confusing initialization/serialization behavior. Consider removing the defaults (making them required) or making the fields optional. Also, GPyTorch IndexKernel typically expects rank <= num_tasks; consider adding a validator to enforce that constraint.
```diff
-num_tasks: int = field(default=None, validator=[instance_of(int), ge(2)])
-"""The number of tasks."""
-
-rank: int = field(default=None, validator=[instance_of(int), ge(1)])
-"""The rank of the task covariance matrix."""
+num_tasks: int = field(validator=[instance_of(int), ge(2)])
+"""The number of tasks."""
+
+rank: int = field(validator=[instance_of(int), ge(1)])
+"""The rank of the task covariance matrix."""
+
+@rank.validator
+def _validate_rank_not_exceed_num_tasks(self, attribute, value) -> None:
+    """Ensure that the rank does not exceed the number of tasks."""
+    if value > self.num_tasks:
+        raise ValueError(
+            "IndexKernel 'rank' must be less than or equal to 'num_tasks'. "
+            f"Got rank={value}, num_tasks={self.num_tasks}."
+        )
```
```python
index_kernels = st.builds(
    IndexKernel,
    num_tasks=st.integers(min_value=2, max_value=5),
    rank=st.integers(min_value=1, max_value=5),
```
This strategy can generate invalid combinations where rank > num_tasks (e.g., num_tasks=2, rank=5), which may fail when converting/using a GPyTorch IndexKernel. Constrain rank to be <= num_tasks (e.g., generate num_tasks first and then rank=st.integers(1, num_tasks)).
```diff
-index_kernels = st.builds(
-    IndexKernel,
-    num_tasks=st.integers(min_value=2, max_value=5),
-    rank=st.integers(min_value=1, max_value=5),
+index_kernels = st.integers(min_value=2, max_value=5).flatmap(
+    lambda num_tasks: st.builds(
+        IndexKernel,
+        num_tasks=st.just(num_tasks),
+        rank=st.integers(min_value=1, max_value=num_tasks),
+    ),
```
```python
PresetKernelFactory: type[KernelFactory]
PresetMeanFactory: type[KernelFactory]
PresetLikelihoodFactory: type[KernelFactory]
```
`PresetMeanFactory` and `PresetLikelihoodFactory` are annotated as `type[KernelFactory]`, which is misleading and defeats type checking/documentation for preset construction. These should be `type[MeanFactory]` and `type[LikelihoodFactory]` (and imported accordingly).
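A sketch of the corrected annotations, assuming `MeanFactory` and `LikelihoodFactory` are importable from the corresponding component modules:

```python
PresetKernelFactory: type[KernelFactory]
PresetMeanFactory: type[MeanFactory]
PresetLikelihoodFactory: type[LikelihoodFactory]
```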
```md
  for determining the Pareto front
- Support for GPyTorch objects (kernels, means, likelihood) as Gaussian process
  components, enabling full low-level customization
- `EDBO` and `SMOOTHED_EDBO` presets for `GaussianProcessSurrogate`
```
The changelog lists the preset name as SMOOTHED_EDBO, but the code introduces GaussianProcessPreset.EDBO_SMOOTHED. Please align the documented preset name with the actual enum value to avoid confusing users.
```diff
-- `EDBO` and `SMOOTHED_EDBO` presets for `GaussianProcessSurrogate`
+- `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate`
```
Force-pushed from 66b5748 to 917da44.
At the moment it's not very clear to me why a BayBE `IndexKernel` class is needed. But I'm sure there is a reason; could you please provide it in the PR description?
Naively, I'd think that:
- the index kernel is a special case that doesn't require user-selectability, so why do we need a BayBE class for it?
- given your planned changes to allow arbitrary GPyTorch kernel classes, why is this class required here?
Adds the `IndexKernel` class