diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index 568e6d04335..0a3c7547f25 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -47,9 +47,10 @@ from botorch.models.multitask import MultiTaskGP from botorch.utils.types import DEFAULT from gpytorch.kernels.matern_kernel import MaternKernel +from gpytorch.kernels.rbf_kernel import RBFKernel from gpytorch.kernels.scale_kernel import ScaleKernel from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood -from gpytorch.priors.torch_priors import GammaPrior +from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior class ModelRegistryTest(TestCase): @@ -452,12 +453,7 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None: ] ) - lengthscale_priors = [ - GammaPrior(6.0, 3.0), - GammaPrior(3.0, 6.0), - ] - - for surrogate, lengthscale_prior in zip(surrogates, lengthscale_priors): + for surrogate, default_model in zip(surrogates, (False, True)): constructor = Models.SAAS_MTGP if use_saas else Models.ST_MTGP mtgp = constructor( experiment=exp, @@ -468,26 +464,25 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None: self.assertIsInstance(mtgp, TorchModelBridge) self.assertIsInstance(mtgp.model, BoTorchModel) self.assertEqual(mtgp.model.acquisition_class, Acquisition) - self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP) - models = mtgp.model.surrogate.model.models - for i in range(len(models)): + for model in mtgp.model.surrogate.model.models: self.assertIsInstance( - models[i], + model, SaasFullyBayesianMultiTaskGP if use_saas else MultiTaskGP, ) - if use_saas is False: - self.assertIsInstance(models[i].covar_module, ScaleKernel) - base_kernel = models[i].covar_module.base_kernel + if use_saas is False and default_model is False: + self.assertIsInstance(model.covar_module, ScaleKernel) + base_kernel = model.covar_module.base_kernel self.assertIsInstance(base_kernel, MaternKernel) self.assertEqual( - base_kernel.lengthscale_prior.concentration, - lengthscale_prior.concentration, + base_kernel.lengthscale_prior.concentration, 6.0 ) - self.assertEqual( - base_kernel.lengthscale_prior.rate, - lengthscale_prior.rate, + self.assertEqual(base_kernel.lengthscale_prior.rate, 3.0) + elif use_saas is False: + self.assertIsInstance(model.covar_module, RBFKernel) + self.assertIsInstance( + model.covar_module.lengthscale_prior, LogNormalPrior ) gr = mtgp.gen(