Skip to content

Commit

Permalink
Update the remaining models to use new default covar & likelihood mod…
Browse files Browse the repository at this point in the history
…ules (#2507)

Summary:
X-link: facebook/Ax#2742

Pull Request resolved: #2507

Updates the default covar & likelihood modules of BoTorch models. See #2451 for details on the new defaults.

For models that utilize a composite kernel, such as multi-fidelity/task/context, this change only affects the base kernel.

Exceptions / Models that do not utilize the new modules:
- Fully-bayesian models.
- Pairwise GP.
- Higher order GP: Produced weird division by zero errors after the change.
- Fidelity kernels for MF models.
- (likelihood only) Any model that utilizes a likelihood other than `GaussianLikelihood` (e.g., `MultiTaskGaussianLikelihood`).

Reviewed By: esantorella

Differential Revision: D62196414

fbshipit-source-id: e2c8983a49a9f00d878e1fb7cf346212acb895e9
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Sep 5, 2024
1 parent 3db1a0e commit 1417189
Show file tree
Hide file tree
Showing 17 changed files with 74 additions and 115 deletions.
17 changes: 6 additions & 11 deletions botorch/models/approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.utils import validate_input_scaling
from botorch.models.utils.gpytorch_modules import (
get_gaussian_likelihood_with_gamma_prior,
get_matern_kernel_with_gamma_prior,
get_covar_module_with_dim_scaled_prior,
get_gaussian_likelihood_with_lognormal_prior,
)
from botorch.models.utils.inducing_point_allocators import (
GreedyVarianceReduction,
Expand Down Expand Up @@ -193,7 +193,7 @@ def __init__(
this does not have to be all of the training inputs).
train_Y: Not used.
num_outputs: Number of output responses per input.
covar_module: Kernel function. If omitted, uses a `MaternKernel`.
covar_module: Kernel function. If omitted, uses an `RBFKernel`.
mean_module: Mean of GP model. If omitted, uses a `ConstantMean`.
variational_distribution: Type of variational distribution to use
(default: CholeskyVariationalDistribution), the properties of the
Expand All @@ -217,15 +217,10 @@ def __init__(
self._aug_batch_shape = aug_batch_shape

if covar_module is None:
covar_module = get_matern_kernel_with_gamma_prior(
covar_module = get_covar_module_with_dim_scaled_prior(
ard_num_dims=train_X.shape[-1],
batch_shape=self._aug_batch_shape,
).to(train_X)
self._subset_batch_dict = {
"mean_module.constant": -2,
"covar_module.raw_outputscale": -1,
"covar_module.base_kernel.raw_lengthscale": -3,
}

if inducing_point_allocator is None:
inducing_point_allocator = GreedyVarianceReduction()
Expand Down Expand Up @@ -343,7 +338,7 @@ def __init__(
either a `GaussianLikelihood` (if `num_outputs=1`) or a
`MultitaskGaussianLikelihood`(if `num_outputs>1`).
num_outputs: Number of output responses per input (default: 1).
covar_module: Kernel function. If omitted, uses a `MaternKernel`.
covar_module: Kernel function. If omitted, uses an `RBFKernel`.
mean_module: Mean of GP model. If omitted, uses a `ConstantMean`.
variational_distribution: Type of variational distribution to use
(default: CholeskyVariationalDistribution), the properties of the
Expand Down Expand Up @@ -378,7 +373,7 @@ def __init__(

if likelihood is None:
if num_outputs == 1:
likelihood = get_gaussian_likelihood_with_gamma_prior(
likelihood = get_gaussian_likelihood_with_lognormal_prior(
batch_shape=self._aug_batch_shape
)
else:
Expand Down
2 changes: 1 addition & 1 deletion botorch/models/contextual_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
is common across all tasks.
mean_module: The mean function to be used. Defaults to `ConstantMean`.
covar_module: The module for computing the covariance matrix between
the non-task features. Defaults to `MaternKernel`.
the non-task features. Defaults to `RBFKernel`.
likelihood: A likelihood. The default is selected based on `train_Yvar`.
If `train_Yvar` is None, a standard `GaussianLikelihood` with inferred
noise level is used. Otherwise, a FixedNoiseGaussianLikelihood is used.
Expand Down
3 changes: 2 additions & 1 deletion botorch/models/gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __init__(
is None, and a `FixedNoiseGaussianLikelihood` with the given
noise observations if `train_Yvar` is not None.
covar_module: The module computing the covariance (Kernel) matrix.
If omitted, use a `MaternKernel`.
If omitted, uses an `RBFKernel`.
mean_module: The mean function to be used. If omitted, use a
`ConstantMean`.
outcome_transform: An outcome transform that is applied to the
Expand Down Expand Up @@ -207,6 +207,7 @@ def __init__(
ard_num_dims=transformed_X.shape[-1],
batch_shape=self._aug_batch_shape,
)
# Used for subsetting along the output dimension. See Model.subset_output.
self._subset_batch_dict = {
"mean_module.raw_constant": -1,
"covar_module.raw_lengthscale": -3,
Expand Down
7 changes: 3 additions & 4 deletions botorch/models/gp_regression_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from __future__ import annotations

import warnings

from typing import Any, Optional, Union

import torch
Expand All @@ -39,9 +38,9 @@
)
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior
from botorch.utils.datasets import SupervisedDataset
from gpytorch.kernels.kernel import ProductKernel
from gpytorch.kernels.rbf_kernel import RBFKernel
from gpytorch.kernels.scale_kernel import ScaleKernel
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.priors.torch_priors import GammaPrior
Expand Down Expand Up @@ -153,6 +152,7 @@ def __init__(
outcome_transform=outcome_transform,
input_transform=input_transform,
)
# Used for subsetting along the output dimension. See Model.subset_output.
self._subset_batch_dict = {
"mean_module.raw_constant": -1,
"covar_module.raw_outputscale": -1,
Expand Down Expand Up @@ -273,10 +273,9 @@ def _setup_multifidelity_covar_module(
non_active_dims.add(iteration_fidelity)
active_dimsX = sorted(set(range(dim)) - non_active_dims)
kernels.append(
RBFKernel(
get_covar_module_with_dim_scaled_prior(
ard_num_dims=len(active_dimsX),
batch_shape=aug_batch_shape,
lengthscale_prior=GammaPrior(3.0, 6.0),
active_dims=active_dimsX,
)
)
Expand Down
31 changes: 3 additions & 28 deletions botorch/models/gp_regression_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@
from botorch.models.kernels.categorical import CategoricalKernel
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.transforms import normalize_indices
from gpytorch.constraints import GreaterThan
from gpytorch.kernels.kernel import Kernel
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.kernels.scale_kernel import ScaleKernel
from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.priors import GammaPrior
from torch import Tensor


Expand Down Expand Up @@ -82,7 +80,7 @@ def __init__(
cont_kernel_factory: A method that accepts `batch_shape`, `ard_num_dims`,
and `active_dims` arguments and returns an instantiated GPyTorch
`Kernel` object to be used as the base kernel for the continuous
dimensions. If omitted, this model uses a Matern-2.5 kernel as
dimensions. If omitted, this model uses an `RBFKernel` as
the kernel for the ordinal parameters.
likelihood: A likelihood. If omitted, use a standard
GaussianLikelihood with inferred noise level.
Expand All @@ -105,30 +103,7 @@ def __init__(
_, aug_batch_shape = self.get_batch_dimensions(train_X=train_X, train_Y=train_Y)

if cont_kernel_factory is None:

def cont_kernel_factory(
batch_shape: torch.Size,
ard_num_dims: int,
active_dims: list[int],
) -> MaternKernel:
return MaternKernel(
nu=2.5,
batch_shape=batch_shape,
ard_num_dims=ard_num_dims,
active_dims=active_dims,
lengthscale_constraint=GreaterThan(1e-04),
)

if likelihood is None and train_Yvar is None:
# This Gamma prior is quite close to the Horseshoe prior
min_noise = 1e-5 if train_X.dtype == torch.float else 1e-6
likelihood = GaussianLikelihood(
batch_shape=aug_batch_shape,
noise_constraint=GreaterThan(
min_noise, transform=None, initial_value=1e-3
),
noise_prior=GammaPrior(0.9, 10.0),
)
cont_kernel_factory = get_covar_module_with_dim_scaled_prior

d = train_X.shape[-1]
cat_dims = normalize_indices(indices=cat_dims, d=d)
Expand Down
4 changes: 3 additions & 1 deletion botorch/models/gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,9 @@ def subset_output(self, idcs: list[int]) -> BatchedMultiOutputGPyTorchModel:
subset_batch_dict = self._subset_batch_dict
except AttributeError:
raise NotImplementedError(
"subset_output requires the model to define a `_subset_dict` attribute"
"`subset_output` requires the model to define a `_subset_batch_dict` "
"attribute that lists the indices of the output dimensions in each "
"model parameter that needs to be subset."
)

m = len(idcs)
Expand Down
10 changes: 3 additions & 7 deletions botorch/models/kernels/contextual_lcea.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from typing import Any, Optional

import torch
from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior
from gpytorch.constraints import Positive
from gpytorch.kernels.kernel import Kernel
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.priors.torch_priors import GammaPrior
from linear_operator.operators import DiagLinearOperator
from linear_operator.operators.dense_linear_operator import DenseLinearOperator
Expand Down Expand Up @@ -158,18 +158,14 @@ def __init__(
if train_embedding:
self._set_emb_layers()
# task covariance matrix
self.task_covar_module = MaternKernel(
nu=2.5,
self.task_covar_module = get_covar_module_with_dim_scaled_prior(
ard_num_dims=self.n_embs,
batch_shape=batch_shape,
lengthscale_prior=GammaPrior(3.0, 6.0),
)
# base kernel
self.base_kernel = MaternKernel(
nu=2.5,
self.base_kernel = get_covar_module_with_dim_scaled_prior(
ard_num_dims=self.num_param,
batch_shape=batch_shape,
lengthscale_prior=GammaPrior(3.0, 6.0),
)
# outputscales for each context (note this is like sqrt of outputscale)
self.context_weight = None
Expand Down
8 changes: 3 additions & 5 deletions botorch/models/kernels/contextual_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from typing import Any, Optional

import torch
from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior
from gpytorch.kernels.kernel import Kernel
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.kernels.scale_kernel import ScaleKernel
from gpytorch.priors.torch_priors import GammaPrior
from linear_operator.operators.sum_linear_operator import SumLinearOperator
Expand Down Expand Up @@ -36,7 +36,7 @@ class SACKernel(Kernel):
where
* :math: M is the number of partitions of parameter space. Each partition contains
same number of parameters d. Each kernel `k_i` acts only on d parameters of ith
partition i.e. `\mathbf{x}_(i)`. Each kernel `k_i` is a scaled Matern kernel
partition i.e. `\mathbf{x}_(i)`. Each kernel `k_i` is a scaled RBF kernel
with same lengthscales but different outputscales.
"""

Expand Down Expand Up @@ -72,11 +72,9 @@ def __init__(
for context, active_params in self.decomposition.items()
}

self.base_kernel = MaternKernel(
nu=2.5,
self.base_kernel = get_covar_module_with_dim_scaled_prior(
ard_num_dims=num_param,
batch_shape=batch_shape,
lengthscale_prior=GammaPrior(3.0, 6.0),
)

self.kernel_dict = {} # scaled kernel for each parameter space partition
Expand Down
21 changes: 8 additions & 13 deletions botorch/models/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.utils.gpytorch_modules import (
get_matern_kernel_with_gamma_prior,
get_covar_module_with_dim_scaled_prior,
get_gaussian_likelihood_with_lognormal_prior,
MIN_INFERRED_NOISE_LEVEL,
)
from botorch.posteriors.multitask import MultitaskGPPosterior
Expand All @@ -51,12 +52,8 @@
)
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels.index_kernel import IndexKernel
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.kernels.multitask_kernel import MultitaskKernel
from gpytorch.likelihoods.gaussian_likelihood import (
FixedNoiseGaussianLikelihood,
GaussianLikelihood,
)
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.likelihoods.multitask_gaussian_likelihood import (
MultitaskGaussianLikelihood,
Expand Down Expand Up @@ -167,7 +164,7 @@ def __init__(
Note that the inferred noise is common across all tasks.
mean_module: The mean function to be used. Defaults to `ConstantMean`.
covar_module: The module for computing the covariance matrix between
the non-task features. Defaults to `MaternKernel`.
the non-task features. Defaults to `RBFKernel`.
likelihood: A likelihood. The default is selected based on `train_Yvar`.
If `train_Yvar` is None, a standard `GaussianLikelihood` with inferred
noise level is used. Otherwise, a FixedNoiseGaussianLikelihood is used.
Expand Down Expand Up @@ -233,7 +230,7 @@ def __init__(
# TODO (T41270962): Support task-specific noise levels in likelihood
if likelihood is None:
if train_Yvar is None:
likelihood = GaussianLikelihood(noise_prior=GammaPrior(1.1, 0.05))
likelihood = get_gaussian_likelihood_with_lognormal_prior()
else:
likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar.squeeze(-1))

Expand All @@ -247,7 +244,7 @@ def __init__(
)
self.mean_module = mean_module or ConstantMean()
if covar_module is None:
self.covar_module = get_matern_kernel_with_gamma_prior(
self.covar_module = get_covar_module_with_dim_scaled_prior(
ard_num_dims=self.num_non_task_features
)
else:
Expand Down Expand Up @@ -442,7 +439,7 @@ def __init__(
`MultitaskGaussianLikelihood` with a `GammaPrior(1.1, 0.05)`
noise prior.
data_covar_module: The module computing the covariance (Kernel) matrix
in data space. If omitted, use a `MaternKernel`.
in data space. If omitted, uses an `RBFKernel`.
task_covar_prior : A Prior on the task covariance matrix. Must operate
on p.s.d. matrices. A common prior for this is the `LKJ` prior. If
omitted, uses `LKJCovariancePrior` with `eta` parameter as specified
Expand Down Expand Up @@ -500,10 +497,8 @@ def __init__(
base_means=ConstantMean(batch_shape=batch_shape), num_tasks=num_tasks
)
if data_covar_module is None:
data_covar_module = MaternKernel(
nu=2.5,
data_covar_module = get_covar_module_with_dim_scaled_prior(
ard_num_dims=ard_num_dims,
lengthscale_prior=GammaPrior(3.0, 6.0),
batch_shape=batch_shape,
)
else:
Expand Down
10 changes: 8 additions & 2 deletions botorch/models/utils/gpytorch_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""

from math import log, sqrt
from typing import Optional, Union
from typing import Optional, Sequence, Union

import torch
from gpytorch.constraints.constraints import GreaterThan
Expand Down Expand Up @@ -101,14 +101,18 @@ def get_covar_module_with_dim_scaled_prior(
ard_num_dims: int,
batch_shape: Optional[torch.Size] = None,
use_rbf_kernel: bool = True,
) -> Union[MaternKernel, RBFKernel, ScaleKernel]:
active_dims: Optional[Sequence[int]] = None,
) -> Union[MaternKernel, RBFKernel]:
"""Returns an RBF or Matern kernel with priors
from [Hvarfner2024vanilla]_.
Args:
ard_num_dims: Number of feature dimensions for ARD.
batch_shape: Batch shape for the covariance module.
use_rbf_kernel: Whether to use an RBF kernel. If False, uses a Matern kernel.
active_dims: The set of input dimensions to compute the covariances on.
By default, the covariance is computed using the full input tensor.
Set this if you'd like to ignore certain dimensions.
Returns:
A Kernel constructed according to the given arguments. The prior is constrained
Expand All @@ -123,5 +127,7 @@ def get_covar_module_with_dim_scaled_prior(
lengthscale_constraint=GreaterThan(
2.5e-2, transform=None, initial_value=lengthscale_prior.mode
),
# pyre-ignore[6] GPyTorch type is unnecessarily restrictive.
active_dims=active_dims,
)
return base_kernel
14 changes: 9 additions & 5 deletions docs/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,14 @@ instead.
a fully Bayesian multi-task GP using an ICM kernel. The data kernel uses the
SAAS prior to model high-dimensional parameter spaces.

All of the above models use Matérn 5/2 kernels with Automatic Relevance
Discovery (ARD), and have reasonable priors on hyperparameters that make them
work well in settings where the **input features are normalized to the unit
cube** and the **observations are standardized** (zero mean, unit variance).
All of the above models use RBF kernels with Automatic Relevance Discovery
(ARD), and have reasonable priors on hyperparameters that make them work well in
settings where the **input features are normalized to the unit cube** and the
**observations are standardized** (zero mean, unit variance). The lengthscale
priors scale with the input dimension, which makes them adaptable to both low
and high dimensional problems. See
[this discussion](https://github.com/pytorch/botorch/discussions/2451) for
additional context on the default hyperparameters.

## Other useful models

Expand Down Expand Up @@ -182,6 +186,6 @@ model. If you wish to use gradient-based optimization algorithms, the model
should allow back-propagating gradients through the samples to the model input.

If you happen to implement a model that would be useful for other researchers as
well (and involves more than just swapping out the Matérn kernel for an RBF
well (and involves more than just swapping out the RBF kernel for a Matérn
kernel), please consider [contributing](getting_started#contributing) this model
to BoTorch.
2 changes: 1 addition & 1 deletion test/acquisition/multi_objective/test_monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1841,7 +1841,7 @@ def test_with_multitask(self):
def _test_with_multitask(self, acqf_class: type[AcquisitionFunction]):
# Verify that _set_sampler works with MTGP, KroneckerMTGP and HOGP.
torch.manual_seed(1234)
tkwargs = {"device": self.device, "dtype": torch.double}
tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double}
train_x = torch.rand(6, 2, **tkwargs)
train_y = torch.randn(6, 2, **tkwargs)
mtgp_task = torch.cat(
Expand Down
Loading

0 comments on commit 1417189

Please sign in to comment.