diff --git a/botorch/models/approximate_gp.py b/botorch/models/approximate_gp.py index f9face99fb..5bd4beecb6 100644 --- a/botorch/models/approximate_gp.py +++ b/botorch/models/approximate_gp.py @@ -40,8 +40,8 @@ from botorch.models.transforms.outcome import OutcomeTransform from botorch.models.utils import validate_input_scaling from botorch.models.utils.gpytorch_modules import ( - get_gaussian_likelihood_with_gamma_prior, - get_matern_kernel_with_gamma_prior, + get_covar_module_with_dim_scaled_prior, + get_gaussian_likelihood_with_lognormal_prior, ) from botorch.models.utils.inducing_point_allocators import ( GreedyVarianceReduction, @@ -193,7 +193,7 @@ def __init__( this does not have to be all of the training inputs). train_Y: Not used. num_outputs: Number of output responses per input. - covar_module: Kernel function. If omitted, uses a `MaternKernel`. + covar_module: Kernel function. If omitted, uses an `RBFKernel`. mean_module: Mean of GP model. If omitted, uses a `ConstantMean`. variational_distribution: Type of variational distribution to use (default: CholeskyVariationalDistribution), the properties of the @@ -217,15 +217,10 @@ def __init__( self._aug_batch_shape = aug_batch_shape if covar_module is None: - covar_module = get_matern_kernel_with_gamma_prior( + covar_module = get_covar_module_with_dim_scaled_prior( ard_num_dims=train_X.shape[-1], batch_shape=self._aug_batch_shape, ).to(train_X) - self._subset_batch_dict = { - "mean_module.constant": -2, - "covar_module.raw_outputscale": -1, - "covar_module.base_kernel.raw_lengthscale": -3, - } if inducing_point_allocator is None: inducing_point_allocator = GreedyVarianceReduction() @@ -343,7 +338,7 @@ def __init__( either a `GaussianLikelihood` (if `num_outputs=1`) or a `MultitaskGaussianLikelihood`(if `num_outputs>1`). num_outputs: Number of output responses per input (default: 1). - covar_module: Kernel function. If omitted, uses a `MaternKernel`. + covar_module: Kernel function. If omitted, uses an `RBFKernel`. mean_module: Mean of GP model. If omitted, uses a `ConstantMean`. variational_distribution: Type of variational distribution to use (default: CholeskyVariationalDistribution), the properties of the @@ -378,7 +373,7 @@ def __init__( if likelihood is None: if num_outputs == 1: - likelihood = get_gaussian_likelihood_with_gamma_prior( + likelihood = get_gaussian_likelihood_with_lognormal_prior( batch_shape=self._aug_batch_shape ) else: diff --git a/botorch/models/contextual_multioutput.py b/botorch/models/contextual_multioutput.py index 81178f54d1..dc954a4df3 100644 --- a/botorch/models/contextual_multioutput.py +++ b/botorch/models/contextual_multioutput.py @@ -64,7 +64,7 @@ def __init__( is common across all tasks. mean_module: The mean function to be used. Defaults to `ConstantMean`. covar_module: The module for computing the covariance matrix between - the non-task features. Defaults to `MaternKernel`. + the non-task features. Defaults to `RBFKernel`. likelihood: A likelihood. The default is selected based on `train_Yvar`. If `train_Yvar` is None, a standard `GaussianLikelihood` with inferred noise level is used. Otherwise, a FixedNoiseGaussianLikelihood is used. diff --git a/botorch/models/gp_regression.py b/botorch/models/gp_regression.py index 69bc21d147..9c7f236549 100644 --- a/botorch/models/gp_regression.py +++ b/botorch/models/gp_regression.py @@ -149,7 +149,7 @@ def __init__( is None, and a `FixedNoiseGaussianLikelihood` with the given noise observations if `train_Yvar` is not None. 
covar_module: The module computing the covariance (Kernel) matrix. - If omitted, use a `MaternKernel`. + If omitted, uses an `RBFKernel`. mean_module: The mean function to be used. If omitted, use a `ConstantMean`. outcome_transform: An outcome transform that is applied to the @@ -207,6 +207,7 @@ def __init__( ard_num_dims=transformed_X.shape[-1], batch_shape=self._aug_batch_shape, ) + # Used for subsetting along the output dimension. See Model.subset_output. self._subset_batch_dict = { "mean_module.raw_constant": -1, "covar_module.raw_lengthscale": -3, diff --git a/botorch/models/gp_regression_fidelity.py b/botorch/models/gp_regression_fidelity.py index bb0255e835..99f7f02acc 100644 --- a/botorch/models/gp_regression_fidelity.py +++ b/botorch/models/gp_regression_fidelity.py @@ -26,7 +26,6 @@ from __future__ import annotations import warnings - from typing import Any, Optional, Union import torch @@ -39,9 +38,9 @@ ) from botorch.models.transforms.input import InputTransform from botorch.models.transforms.outcome import OutcomeTransform +from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior from botorch.utils.datasets import SupervisedDataset from gpytorch.kernels.kernel import ProductKernel -from gpytorch.kernels.rbf_kernel import RBFKernel from gpytorch.kernels.scale_kernel import ScaleKernel from gpytorch.likelihoods.likelihood import Likelihood from gpytorch.priors.torch_priors import GammaPrior @@ -153,6 +152,7 @@ def __init__( outcome_transform=outcome_transform, input_transform=input_transform, ) + # Used for subsetting along the output dimension. See Model.subset_output. self._subset_batch_dict = { "mean_module.raw_constant": -1, "covar_module.raw_outputscale": -1, @@ -273,10 +273,9 @@ def _setup_multifidelity_covar_module( non_active_dims.add(iteration_fidelity) active_dimsX = sorted(set(range(dim)) - non_active_dims) kernels.append( - RBFKernel( + get_covar_module_with_dim_scaled_prior( ard_num_dims=len(active_dimsX), batch_shape=aug_batch_shape, - lengthscale_prior=GammaPrior(3.0, 6.0), active_dims=active_dimsX, ) ) diff --git a/botorch/models/gp_regression_mixed.py b/botorch/models/gp_regression_mixed.py index fbb99b870f..5dc4697f04 100644 --- a/botorch/models/gp_regression_mixed.py +++ b/botorch/models/gp_regression_mixed.py @@ -13,15 +13,13 @@ from botorch.models.kernels.categorical import CategoricalKernel from botorch.models.transforms.input import InputTransform from botorch.models.transforms.outcome import OutcomeTransform +from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior from botorch.utils.datasets import SupervisedDataset from botorch.utils.transforms import normalize_indices from gpytorch.constraints import GreaterThan from gpytorch.kernels.kernel import Kernel -from gpytorch.kernels.matern_kernel import MaternKernel from gpytorch.kernels.scale_kernel import ScaleKernel -from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood from gpytorch.likelihoods.likelihood import Likelihood -from gpytorch.priors import GammaPrior from torch import Tensor @@ -82,7 +80,7 @@ def __init__( cont_kernel_factory: A method that accepts `batch_shape`, `ard_num_dims`, and `active_dims` arguments and returns an instantiated GPyTorch `Kernel` object to be used as the base kernel for the continuous - dimensions. If omitted, this model uses a Matern-2.5 kernel as + dimensions. If omitted, this model uses an `RBFKernel` as the kernel for the ordinal parameters. likelihood: A likelihood. 
If omitted, use a standard GaussianLikelihood with inferred noise level. @@ -105,30 +103,7 @@ def __init__( _, aug_batch_shape = self.get_batch_dimensions(train_X=train_X, train_Y=train_Y) if cont_kernel_factory is None: - - def cont_kernel_factory( - batch_shape: torch.Size, - ard_num_dims: int, - active_dims: list[int], - ) -> MaternKernel: - return MaternKernel( - nu=2.5, - batch_shape=batch_shape, - ard_num_dims=ard_num_dims, - active_dims=active_dims, - lengthscale_constraint=GreaterThan(1e-04), - ) - - if likelihood is None and train_Yvar is None: - # This Gamma prior is quite close to the Horseshoe prior - min_noise = 1e-5 if train_X.dtype == torch.float else 1e-6 - likelihood = GaussianLikelihood( - batch_shape=aug_batch_shape, - noise_constraint=GreaterThan( - min_noise, transform=None, initial_value=1e-3 - ), - noise_prior=GammaPrior(0.9, 10.0), - ) + cont_kernel_factory = get_covar_module_with_dim_scaled_prior d = train_X.shape[-1] cat_dims = normalize_indices(indices=cat_dims, d=d) diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py index d5adb2b45e..3bc2059c14 100644 --- a/botorch/models/gpytorch.py +++ b/botorch/models/gpytorch.py @@ -543,7 +543,9 @@ def subset_output(self, idcs: list[int]) -> BatchedMultiOutputGPyTorchModel: subset_batch_dict = self._subset_batch_dict except AttributeError: raise NotImplementedError( - "subset_output requires the model to define a `_subset_dict` attribute" + "`subset_output` requires the model to define a `_subset_batch_dict` " + "attribute that lists the indices of the output dimensions in each " + "model parameter that needs to be subset." ) m = len(idcs) diff --git a/botorch/models/kernels/contextual_lcea.py b/botorch/models/kernels/contextual_lcea.py index c33526ec2c..df30cc6a0d 100644 --- a/botorch/models/kernels/contextual_lcea.py +++ b/botorch/models/kernels/contextual_lcea.py @@ -7,9 +7,9 @@ from typing import Any, Optional import torch +from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior from gpytorch.constraints import Positive from gpytorch.kernels.kernel import Kernel -from gpytorch.kernels.matern_kernel import MaternKernel from gpytorch.priors.torch_priors import GammaPrior from linear_operator.operators import DiagLinearOperator from linear_operator.operators.dense_linear_operator import DenseLinearOperator @@ -158,18 +158,14 @@ def __init__( if train_embedding: self._set_emb_layers() # task covariance matrix - self.task_covar_module = MaternKernel( - nu=2.5, + self.task_covar_module = get_covar_module_with_dim_scaled_prior( ard_num_dims=self.n_embs, batch_shape=batch_shape, - lengthscale_prior=GammaPrior(3.0, 6.0), ) # base kernel - self.base_kernel = MaternKernel( - nu=2.5, + self.base_kernel = get_covar_module_with_dim_scaled_prior( ard_num_dims=self.num_param, batch_shape=batch_shape, - lengthscale_prior=GammaPrior(3.0, 6.0), ) # outputscales for each context (note this is like sqrt of outputscale) self.context_weight = None diff --git a/botorch/models/kernels/contextual_sac.py b/botorch/models/kernels/contextual_sac.py index dcc13f2f29..dbd09b344f 100644 --- a/botorch/models/kernels/contextual_sac.py +++ b/botorch/models/kernels/contextual_sac.py @@ -7,8 +7,8 @@ from typing import Any, Optional import torch +from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior from gpytorch.kernels.kernel import Kernel -from gpytorch.kernels.matern_kernel import MaternKernel from gpytorch.kernels.scale_kernel import ScaleKernel from 
gpytorch.priors.torch_priors import GammaPrior from linear_operator.operators.sum_linear_operator import SumLinearOperator @@ -36,7 +36,7 @@ class SACKernel(Kernel): where * :math: M is the number of partitions of parameter space. Each partition contains same number of parameters d. Each kernel `k_i` acts only on d parameters of ith - partition i.e. `\mathbf{x}_(i)`. Each kernel `k_i` is a scaled Matern kernel + partition i.e. `\mathbf{x}_(i)`. Each kernel `k_i` is a scaled RBF kernel with same lengthscales but different outputscales. """ @@ -72,11 +72,9 @@ def __init__( for context, active_params in self.decomposition.items() } - self.base_kernel = MaternKernel( - nu=2.5, + self.base_kernel = get_covar_module_with_dim_scaled_prior( ard_num_dims=num_param, batch_shape=batch_shape, - lengthscale_prior=GammaPrior(3.0, 6.0), ) self.kernel_dict = {} # scaled kernel for each parameter space partition diff --git a/botorch/models/multitask.py b/botorch/models/multitask.py index 575577b038..807b4b55c9 100644 --- a/botorch/models/multitask.py +++ b/botorch/models/multitask.py @@ -40,7 +40,8 @@ from botorch.models.transforms.input import InputTransform from botorch.models.transforms.outcome import OutcomeTransform from botorch.models.utils.gpytorch_modules import ( - get_matern_kernel_with_gamma_prior, + get_covar_module_with_dim_scaled_prior, + get_gaussian_likelihood_with_lognormal_prior, MIN_INFERRED_NOISE_LEVEL, ) from botorch.posteriors.multitask import MultitaskGPPosterior @@ -51,12 +52,8 @@ ) from gpytorch.distributions.multivariate_normal import MultivariateNormal from gpytorch.kernels.index_kernel import IndexKernel -from gpytorch.kernels.matern_kernel import MaternKernel from gpytorch.kernels.multitask_kernel import MultitaskKernel -from gpytorch.likelihoods.gaussian_likelihood import ( - FixedNoiseGaussianLikelihood, - GaussianLikelihood, -) +from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood from gpytorch.likelihoods.likelihood import Likelihood from gpytorch.likelihoods.multitask_gaussian_likelihood import ( MultitaskGaussianLikelihood, @@ -167,7 +164,7 @@ def __init__( Note that the inferred noise is common across all tasks. mean_module: The mean function to be used. Defaults to `ConstantMean`. covar_module: The module for computing the covariance matrix between - the non-task features. Defaults to `MaternKernel`. + the non-task features. Defaults to `RBFKernel`. likelihood: A likelihood. The default is selected based on `train_Yvar`. If `train_Yvar` is None, a standard `GaussianLikelihood` with inferred noise level is used. Otherwise, a FixedNoiseGaussianLikelihood is used. @@ -233,7 +230,7 @@ def __init__( # TODO (T41270962): Support task-specific noise levels in likelihood if likelihood is None: if train_Yvar is None: - likelihood = GaussianLikelihood(noise_prior=GammaPrior(1.1, 0.05)) + likelihood = get_gaussian_likelihood_with_lognormal_prior() else: likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar.squeeze(-1)) @@ -247,7 +244,7 @@ def __init__( ) self.mean_module = mean_module or ConstantMean() if covar_module is None: - self.covar_module = get_matern_kernel_with_gamma_prior( + self.covar_module = get_covar_module_with_dim_scaled_prior( ard_num_dims=self.num_non_task_features ) else: @@ -442,7 +439,7 @@ def __init__( `MultitaskGaussianLikelihood` with a `GammaPrior(1.1, 0.05)` noise prior. data_covar_module: The module computing the covariance (Kernel) matrix - in data space. If omitted, use a `MaternKernel`. + in data space. 
If omitted, uses an `RBFKernel`. task_covar_prior : A Prior on the task covariance matrix. Must operate on p.s.d. matrices. A common prior for this is the `LKJ` prior. If omitted, uses `LKJCovariancePrior` with `eta` parameter as specified @@ -500,10 +497,8 @@ def __init__( base_means=ConstantMean(batch_shape=batch_shape), num_tasks=num_tasks ) if data_covar_module is None: - data_covar_module = MaternKernel( - nu=2.5, + data_covar_module = get_covar_module_with_dim_scaled_prior( ard_num_dims=ard_num_dims, - lengthscale_prior=GammaPrior(3.0, 6.0), batch_shape=batch_shape, ) else: diff --git a/botorch/models/utils/gpytorch_modules.py b/botorch/models/utils/gpytorch_modules.py index 4d231ecc24..9f096d6bf7 100644 --- a/botorch/models/utils/gpytorch_modules.py +++ b/botorch/models/utils/gpytorch_modules.py @@ -18,7 +18,7 @@ """ from math import log, sqrt -from typing import Optional, Union +from typing import Optional, Sequence, Union import torch from gpytorch.constraints.constraints import GreaterThan @@ -101,7 +101,8 @@ def get_covar_module_with_dim_scaled_prior( ard_num_dims: int, batch_shape: Optional[torch.Size] = None, use_rbf_kernel: bool = True, -) -> Union[MaternKernel, RBFKernel, ScaleKernel]: + active_dims: Optional[Sequence[int]] = None, +) -> Union[MaternKernel, RBFKernel]: """Returns an RBF or Matern kernel with priors from [Hvarfner2024vanilla]_. @@ -109,6 +110,9 @@ ard_num_dims: Number of feature dimensions for ARD. batch_shape: Batch shape for the covariance module. use_rbf_kernel: Whether to use an RBF kernel. If False, uses a Matern kernel. + active_dims: The set of input dimensions to compute the covariances on. + By default, the covariance is computed using the full input tensor. + Set this if you'd like to ignore certain dimensions. Returns: A Kernel constructed according to the given arguments. The prior is constrained @@ -123,5 +127,7 @@ lengthscale_constraint=GreaterThan( 2.5e-2, transform=None, initial_value=lengthscale_prior.mode ), + # pyre-ignore[6] GPyTorch type is unnecessarily restrictive. + active_dims=active_dims, ) return base_kernel diff --git a/docs/models.md b/docs/models.md index 0b59bcf449..79d4067f7c 100644 --- a/docs/models.md +++ b/docs/models.md @@ -121,10 +121,14 @@ instead. a fully Bayesian multi-task GP using an ICM kernel. The data kernel uses the SAAS prior to model high-dimensional parameter spaces. -All of the above models use Matérn 5/2 kernels with Automatic Relevance -Discovery (ARD), and have reasonable priors on hyperparameters that make them -work well in settings where the **input features are normalized to the unit -cube** and the **observations are standardized** (zero mean, unit variance). +All of the above models use RBF kernels with Automatic Relevance Determination +(ARD), and have reasonable priors on hyperparameters that make them work well in +settings where the **input features are normalized to the unit cube** and the +**observations are standardized** (zero mean, unit variance). The lengthscale +priors scale with the input dimension, which makes them adaptable to both low- +and high-dimensional problems. See +[this discussion](https://github.com/pytorch/botorch/discussions/2451) for +additional context on the default hyperparameters. ## Other useful models @@ -182,6 +186,6 @@ model. If you wish to use gradient-based optimization algorithms, the model should allow back-propagating gradients through the samples to the model input.
If you happen to implement a model that would be useful for other researchers as -well (and involves more than just swapping out the Matérn kernel for an RBF +well (and involves more than just swapping out the RBF kernel for a Matérn kernel), please consider [contributing](getting_started#contributing) this model to BoTorch. diff --git a/test/acquisition/multi_objective/test_monte_carlo.py b/test/acquisition/multi_objective/test_monte_carlo.py index b29bd92bf1..91583350ff 100644 --- a/test/acquisition/multi_objective/test_monte_carlo.py +++ b/test/acquisition/multi_objective/test_monte_carlo.py @@ -1841,7 +1841,7 @@ def test_with_multitask(self): def _test_with_multitask(self, acqf_class: type[AcquisitionFunction]): # Verify that _set_sampler works with MTGP, KroneckerMTGP and HOGP. torch.manual_seed(1234) - tkwargs = {"device": self.device, "dtype": torch.double} + tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double} train_x = torch.rand(6, 2, **tkwargs) train_y = torch.randn(6, 2, **tkwargs) mtgp_task = torch.cat( diff --git a/test/models/kernels/test_contextual.py b/test/models/kernels/test_contextual.py index 32c62ba2d1..6647548ced 100644 --- a/test/models/kernels/test_contextual.py +++ b/test/models/kernels/test_contextual.py @@ -14,7 +14,7 @@ from botorch.models.kernels.contextual_sac import SACKernel from botorch.utils.testing import BotorchTestCase -from gpytorch.kernels.matern_kernel import MaternKernel +from gpytorch.kernels.rbf_kernel import RBFKernel from torch import Tensor from torch.nn import ModuleDict @@ -25,7 +25,7 @@ def test_SACKernel(self): kernel = SACKernel(decomposition=decomposition, batch_shape=torch.Size([])) self.assertIsInstance(kernel.kernel_dict, ModuleDict) - self.assertIsInstance(kernel.base_kernel, MaternKernel) + self.assertIsInstance(kernel.base_kernel, RBFKernel) self.assertDictEqual(kernel.decomposition, decomposition) # test diag works well for lazy tensor @@ -46,8 +46,8 @@ def testLCEAKernel(self): # test init self.assertListEqual(kernel.context_list, ["1", "2"]) - self.assertIsInstance(kernel.base_kernel, MaternKernel) - self.assertIsInstance(kernel.task_covar_module, MaternKernel) + self.assertIsInstance(kernel.base_kernel, RBFKernel) + self.assertIsInstance(kernel.task_covar_module, RBFKernel) self.assertEqual(kernel.permutation, [0, 3, 1, 2]) # test raise of ValueError diff --git a/test/models/test_approximate_gp.py b/test/models/test_approximate_gp.py index 995d5f4a81..e7d0b7f980 100644 --- a/test/models/test_approximate_gp.py +++ b/test/models/test_approximate_gp.py @@ -133,7 +133,7 @@ def test_variational_setUp(self): # but that the covariance does have a gradient self.assertIsNotNone( - batched_model.model.covar_module.raw_outputscale.grad + batched_model.model.covar_module.raw_lengthscale.grad ) # check that we always have three outputs diff --git a/test/models/test_gp_regression_mixed.py b/test/models/test_gp_regression_mixed.py index bb07d64470..ee2534622a 100644 --- a/test/models/test_gp_regression_mixed.py +++ b/test/models/test_gp_regression_mixed.py @@ -20,7 +20,7 @@ from botorch.utils.test_helpers import get_pvar_expected from botorch.utils.testing import _get_random_data, BotorchTestCase from gpytorch.kernels.kernel import AdditiveKernel, ProductKernel -from gpytorch.kernels.matern_kernel import MaternKernel +from gpytorch.kernels.rbf_kernel import RBFKernel from gpytorch.kernels.scale_kernel import ScaleKernel from gpytorch.likelihoods import FixedNoiseGaussianLikelihood from 
gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood @@ -90,10 +90,10 @@ def test_gp(self): self.assertIsInstance(prod_kernel.base_kernel, ProductKernel) sum_cont_kernel, sum_cat_kernel = sum_kernel.base_kernel.kernels prod_cont_kernel, prod_cat_kernel = prod_kernel.base_kernel.kernels - self.assertIsInstance(sum_cont_kernel, MaternKernel) + self.assertIsInstance(sum_cont_kernel, RBFKernel) self.assertIsInstance(sum_cat_kernel, ScaleKernel) self.assertIsInstance(sum_cat_kernel.base_kernel, CategoricalKernel) - self.assertIsInstance(prod_cont_kernel, MaternKernel) + self.assertIsInstance(prod_cont_kernel, RBFKernel) self.assertIsInstance(prod_cat_kernel, CategoricalKernel) else: self.assertIsInstance(model.covar_module, ScaleKernel) diff --git a/test/models/test_gpytorch.py b/test/models/test_gpytorch.py index f586ba7df0..4490d4ac38 100644 --- a/test/models/test_gpytorch.py +++ b/test/models/test_gpytorch.py @@ -414,6 +414,7 @@ def test_posterior_transform(self): class TestModelListGPyTorchModel(BotorchTestCase): def test_model_list_gpytorch_model(self): + torch.manual_seed(12345) for dtype in (torch.float, torch.double): tkwargs = {"device": self.device, "dtype": dtype} train_X1, train_X2 = ( @@ -512,16 +513,15 @@ def test_model_list_gpytorch_model(self): self.assertEqual( posterior_subset.mean.shape, torch.Size([2, len(output_indices)]) ) - self.assertTrue( - torch.allclose( - posterior_subset.mean, posterior.mean[..., output_indices] - ) + self.assertAllClose( + posterior_subset.mean, + posterior.mean[..., output_indices], + atol=1e-6, ) - self.assertTrue( - torch.allclose( - posterior_subset.variance, - posterior.variance[..., output_indices], - ) + self.assertAllClose( + posterior_subset.variance, + posterior.variance[..., output_indices], + atol=1e-6, ) # test observation noise model = SimpleModelListGPyTorchModel(m1, m2) diff --git a/test/models/test_multitask.py b/test/models/test_multitask.py index 45218d89a3..9e5582451e 100644 --- a/test/models/test_multitask.py +++ b/test/models/test_multitask.py @@ -26,13 +26,7 @@ from botorch.utils.test_helpers import gen_multi_task_dataset from botorch.utils.testing import BotorchTestCase from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal -from gpytorch.kernels import ( - IndexKernel, - MaternKernel, - MultitaskKernel, - RBFKernel, - ScaleKernel, -) +from gpytorch.kernels import IndexKernel, MaternKernel, MultitaskKernel, RBFKernel from gpytorch.likelihoods import ( FixedNoiseGaussianLikelihood, GaussianLikelihood, @@ -162,10 +156,8 @@ def test_MultiTaskGP(self) -> None: else: self.assertIsInstance(model.likelihood, GaussianLikelihood) self.assertIsInstance(model.mean_module, ConstantMean) - self.assertIsInstance(model.covar_module, ScaleKernel) - matern_kernel = model.covar_module.base_kernel - self.assertIsInstance(matern_kernel, MaternKernel) - self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior) + self.assertIsInstance(model.covar_module, RBFKernel) + self.assertIsInstance(model.covar_module.lengthscale_prior, LogNormalPrior) self.assertIsInstance(model.task_covar_module, IndexKernel) self.assertEqual(model._rank, 2) self.assertEqual( @@ -334,10 +326,8 @@ def test_MultiTaskGP_single_output(self) -> None: self.assertEqual(model.num_outputs, 1) self.assertIsInstance(model.likelihood, GaussianLikelihood) self.assertIsInstance(model.mean_module, ConstantMean) - self.assertIsInstance(model.covar_module, ScaleKernel) - matern_kernel = model.covar_module.base_kernel - 
self.assertIsInstance(matern_kernel, MaternKernel) - self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior) + self.assertIsInstance(model.covar_module, RBFKernel) + self.assertIsInstance(model.covar_module.lengthscale_prior, LogNormalPrior) self.assertIsInstance(model.task_covar_module, IndexKernel) self.assertEqual(model._rank, 2) self.assertEqual( @@ -534,16 +524,14 @@ def test_KroneckerMultiTaskGP_default(self) -> None: self.assertIsInstance(model.mean_module, MultitaskMean) self.assertIsInstance(model.covar_module, MultitaskKernel) base_kernel = model.covar_module - self.assertIsInstance(base_kernel.data_covar_module, MaternKernel) + self.assertIsInstance(base_kernel.data_covar_module, RBFKernel) self.assertIsInstance(base_kernel.task_covar_module, IndexKernel) task_covar_prior = base_kernel.task_covar_module.IndexKernelPrior self.assertIsInstance(task_covar_prior, LKJCovariancePrior) self.assertEqual(task_covar_prior.correlation_prior.eta, 1.5) self.assertIsInstance(task_covar_prior.sd_prior, SmoothedBoxPrior) lengthscale_prior = base_kernel.data_covar_module.lengthscale_prior - self.assertIsInstance(lengthscale_prior, GammaPrior) - self.assertEqual(lengthscale_prior.concentration, 3.0) - self.assertEqual(lengthscale_prior.rate, 6.0) + self.assertIsInstance(lengthscale_prior, LogNormalPrior) self.assertEqual(base_kernel.task_covar_module.covar_factor.shape[-1], 2) # test model fitting
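
Note for reviewers (not part of the patch): a minimal sketch of how the new defaults compose after this change. It assumes the post-diff signatures of the two helpers in `botorch/models/utils/gpytorch_modules.py` as shown in the hunks above.

```python
# Illustrative sketch only -- assumes the post-diff helper signatures.
import torch
from botorch.models.utils.gpytorch_modules import (
    get_covar_module_with_dim_scaled_prior,
    get_gaussian_likelihood_with_lognormal_prior,
)

d = 12  # input dimension; per docs/models.md, the lengthscale prior scales with d
train_X = torch.rand(20, d, dtype=torch.double)

# New default data kernel: a bare RBFKernel (no ScaleKernel wrapper) with an
# ARD LogNormal lengthscale prior, replacing ScaleKernel(MaternKernel) with
# GammaPrior(3.0, 6.0) lengthscales.
covar_module = get_covar_module_with_dim_scaled_prior(
    ard_num_dims=d,
    active_dims=None,  # pass-through argument added in this diff
).to(train_X)

# New default likelihood: LogNormal noise prior, replacing GammaPrior(1.1, 0.05).
likelihood = get_gaussian_likelihood_with_lognormal_prior()

print(type(covar_module).__name__)     # RBFKernel
print(covar_module.lengthscale.shape)  # torch.Size([1, 12]) -> ARD over d dims
```

Because the default kernel no longer carries a `raw_outputscale` parameter, downstream code that reached into `covar_module.base_kernel` or listed `raw_outputscale` in `_subset_batch_dict` needs the corresponding updates made above (e.g. `covar_module.raw_lengthscale.grad` in `test_approximate_gp.py`). The same helper now also serves as the default `cont_kernel_factory` in `MixedSingleTaskGP`, since its `batch_shape`/`ard_num_dims`/`active_dims` arguments match the factory protocol.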
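A second hedged sketch, on the clarified `subset_output` error message: `_subset_batch_dict` maps parameter names to the (negative) index of the output dimension in each parameter's shape, which `subset_output` uses to slice a batched multi-output model down to a subset of outputs. The dict values below are taken from the `gp_regression.py` hunk; `subset_param` is a hypothetical stand-in for the slicing that `subset_output` performs, not the actual BoTorch implementation.

```python
import torch

# Values from the SingleTaskGP hunk above: each entry names a parameter and
# the negative index of the output dimension within that parameter's shape.
_subset_batch_dict = {
    "mean_module.raw_constant": -1,      # shape (..., m)
    "covar_module.raw_lengthscale": -3,  # shape (..., m, 1, d)
}

def subset_param(param: torch.Tensor, dim: int, idcs: list[int]) -> torch.Tensor:
    # Conceptually what subset_output does per listed parameter: keep only the
    # requested output indices along the recorded dimension.
    return param.index_select(dim=dim, index=torch.tensor(idcs))

lengthscale = torch.rand(2, 1, 5)  # m=2 outputs, d=5 features
print(subset_param(lengthscale, -3, [0]).shape)  # torch.Size([1, 1, 5])
```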