Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the remaining models to use new default covar & likelihood modules #2742

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 14 additions & 19 deletions ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@
from botorch.models.multitask import MultiTaskGP
from botorch.utils.types import DEFAULT
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.kernels.rbf_kernel import RBFKernel
from gpytorch.kernels.scale_kernel import ScaleKernel
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.priors.torch_priors import GammaPrior
from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior


class ModelRegistryTest(TestCase):
Expand Down Expand Up @@ -452,12 +453,7 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None:
]
)

lengthscale_priors = [
GammaPrior(6.0, 3.0),
GammaPrior(3.0, 6.0),
]

for surrogate, lengthscale_prior in zip(surrogates, lengthscale_priors):
for surrogate, default_model in zip(surrogates, (False, True)):
constructor = Models.SAAS_MTGP if use_saas else Models.ST_MTGP
mtgp = constructor(
experiment=exp,
Expand All @@ -468,26 +464,25 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None:
self.assertIsInstance(mtgp, TorchModelBridge)
self.assertIsInstance(mtgp.model, BoTorchModel)
self.assertEqual(mtgp.model.acquisition_class, Acquisition)

self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP)
models = mtgp.model.surrogate.model.models

for i in range(len(models)):
for model in mtgp.model.surrogate.model.models:
self.assertIsInstance(
models[i],
model,
SaasFullyBayesianMultiTaskGP if use_saas else MultiTaskGP,
)
if use_saas is False:
self.assertIsInstance(models[i].covar_module, ScaleKernel)
base_kernel = models[i].covar_module.base_kernel
if use_saas is False and default_model is False:
self.assertIsInstance(model.covar_module, ScaleKernel)
base_kernel = model.covar_module.base_kernel
self.assertIsInstance(base_kernel, MaternKernel)
self.assertEqual(
base_kernel.lengthscale_prior.concentration,
lengthscale_prior.concentration,
base_kernel.lengthscale_prior.concentration, 6.0
)
self.assertEqual(
base_kernel.lengthscale_prior.rate,
lengthscale_prior.rate,
self.assertEqual(base_kernel.lengthscale_prior.rate, 3.0)
elif use_saas is False:
self.assertIsInstance(model.covar_module, RBFKernel)
self.assertIsInstance(
model.covar_module.lengthscale_prior, LogNormalPrior
)

gr = mtgp.gen(
Expand Down
Loading