From 0dd3b62cb031295a5ddb7422a004d038d484dd58 Mon Sep 17 00:00:00 2001 From: voetberg Date: Thu, 27 Jun 2024 15:08:18 -0500 Subject: [PATCH] Add warning that simulator is missing --- src/deepdiagnostics/client/client.py | 27 ++++++++++++------- src/deepdiagnostics/data/data.py | 8 ++++-- src/deepdiagnostics/metrics/__init__.py | 7 ++++- .../metrics/local_two_sample.py | 3 +++ src/deepdiagnostics/plots/__init__.py | 8 +++++- .../plots/predictive_posterior_check.py | 3 +++ .../plots/predictive_prior_check.py | 5 +++- .../utils/{register.py => simulator_utils.py} | 6 ++++- tests/conftest.py | 2 +- tests/test_client.py | 22 +++++++++++++++ 10 files changed, 74 insertions(+), 17 deletions(-) rename src/deepdiagnostics/utils/{register.py => simulator_utils.py} (96%) diff --git a/src/deepdiagnostics/client/client.py b/src/deepdiagnostics/client/client.py index 59c8024..90ef431 100644 --- a/src/deepdiagnostics/client/client.py +++ b/src/deepdiagnostics/client/client.py @@ -8,6 +8,7 @@ from deepdiagnostics.models import ModelModules from deepdiagnostics.metrics import Metrics from deepdiagnostics.plots import Plots +from deepdiagnostics.utils.simulator_utils import SimulatorMissingError def parser(): @@ -47,19 +48,19 @@ def parser(): # List of metrics (cannot supply specific kwargs) parser.add_argument( "--metrics", - nargs="?", - default=list(Defaults["metrics"].keys()), + nargs="+", + default=[], choices=Metrics.keys(), - help="List of metrics to run. To not run any, supply `--metrics `" + help="List of metrics to run." ) # List of plots parser.add_argument( "--plots", - nargs="?", - default=list(Defaults["plots"].keys()), + nargs="+", + default=[], choices=Plots.keys(), - help="List of plots to run. To not run any, supply `--plots `" + help="List of plots to run." ) @@ -109,9 +110,15 @@ def main(): plots = config.get_section("plots", raise_exception=False) for metrics_name, metrics_args in metrics.items(): - Metrics[metrics_name](model, data, **metrics_args)() + try: + Metrics[metrics_name](model, data, save=True)(**metrics_args) + except SimulatorMissingError: + print(f"Cannot run {metrics_name} - simulator missing.") for plot_name, plot_args in plots.items(): - Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)( - **plot_args - ) + try: + Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)( + **plot_args + ) + except SimulatorMissingError: + print(f"Cannot run {plot_name} - simulator missing.") \ No newline at end of file diff --git a/src/deepdiagnostics/data/data.py b/src/deepdiagnostics/data/data.py index 3b9539f..e536775 100644 --- a/src/deepdiagnostics/data/data.py +++ b/src/deepdiagnostics/data/data.py @@ -2,7 +2,7 @@ import numpy as np from deepdiagnostics.utils.config import get_item -from deepdiagnostics.utils.register import load_simulator +from deepdiagnostics.utils.simulator_utils import load_simulator class Data: """ @@ -35,7 +35,11 @@ def __init__( get_item("common", "random_seed", raise_exception=False) ) self.data = self._load(path) - self.simulator = load_simulator(simulator_name, simulator_kwargs) + try: + self.simulator = load_simulator(simulator_name, simulator_kwargs) + except RuntimeError: + print("Warning: Simulator not loaded. Can only run non-generative metrics.") + self.prior_dist = self.load_prior(prior, prior_kwargs) self.n_dims = self.get_theta_true().shape[1] self.simulator_dimensions = simulation_dimensions if simulation_dimensions is not None else get_item("data", "simulator_dimensions", raise_exception=False) diff --git a/src/deepdiagnostics/metrics/__init__.py b/src/deepdiagnostics/metrics/__init__.py index 18dc3df..a713120 100644 --- a/src/deepdiagnostics/metrics/__init__.py +++ b/src/deepdiagnostics/metrics/__init__.py @@ -2,8 +2,13 @@ from deepdiagnostics.metrics.coverage_fraction import CoverageFraction from deepdiagnostics.metrics.local_two_sample import LocalTwoSampleTest as LC2ST +def void(*args, **kwargs): + def void2(*args, **kwargs): + return None + return void2 + Metrics = { - "": lambda **kwargs: None, + "": void, CoverageFraction.__name__: CoverageFraction, AllSBC.__name__: AllSBC, "LC2ST": LC2ST diff --git a/src/deepdiagnostics/metrics/local_two_sample.py b/src/deepdiagnostics/metrics/local_two_sample.py index c6e7e50..c85c377 100644 --- a/src/deepdiagnostics/metrics/local_two_sample.py +++ b/src/deepdiagnostics/metrics/local_two_sample.py @@ -7,6 +7,7 @@ from sklearn.utils import shuffle from deepdiagnostics.metrics.metric import Metric +from deepdiagnostics.utils.simulator_utils import SimulatorMissingError class LocalTwoSampleTest(Metric): """ @@ -46,6 +47,8 @@ def __init__( percentiles, number_simulations ) + if not hasattr(self.data, "simulator"): + raise SimulatorMissingError("Missing a simulator to run LC2ST.") def _collect_data_params(self): # P is the prior and x_P is generated via the simulator from the parameters P. diff --git a/src/deepdiagnostics/plots/__init__.py b/src/deepdiagnostics/plots/__init__.py index a63b51c..e83274c 100644 --- a/src/deepdiagnostics/plots/__init__.py +++ b/src/deepdiagnostics/plots/__init__.py @@ -7,8 +7,14 @@ from deepdiagnostics.plots.parity import Parity from deepdiagnostics.plots.predictive_prior_check import PriorPC + +def void(*args, **kwargs): + def void2(*args, **kwargs): + return None + return void2 + Plots = { - "": lambda **kwargs: None, + "": void, CDFRanks.__name__: CDFRanks, CoverageFraction.__name__: CoverageFraction, Ranks.__name__: Ranks, diff --git a/src/deepdiagnostics/plots/predictive_posterior_check.py b/src/deepdiagnostics/plots/predictive_posterior_check.py index 09e4162..0b6bcaf 100644 --- a/src/deepdiagnostics/plots/predictive_posterior_check.py +++ b/src/deepdiagnostics/plots/predictive_posterior_check.py @@ -4,6 +4,7 @@ from deepdiagnostics.plots.plot import Display from deepdiagnostics.utils.plotting_utils import get_hex_colors +from deepdiagnostics.utils.simulator_utils import SimulatorMissingError class PPC(Display): """ @@ -33,6 +34,8 @@ def __init__( colorway =None): super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) + if not hasattr(self.data, "simulator"): + raise SimulatorMissingError("Missing a simulator to run PPC.") def plot_name(self): return "predictive_posterior_check.png" diff --git a/src/deepdiagnostics/plots/predictive_prior_check.py b/src/deepdiagnostics/plots/predictive_prior_check.py index bd528c6..8051443 100644 --- a/src/deepdiagnostics/plots/predictive_prior_check.py +++ b/src/deepdiagnostics/plots/predictive_prior_check.py @@ -3,6 +3,7 @@ import numpy as np from deepdiagnostics.plots.plot import Display +from deepdiagnostics.utils.simulator_utils import SimulatorMissingError class PriorPC(Display): """ @@ -36,7 +37,9 @@ def __init__( colorway = None): super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) - + if not hasattr(self.data, "simulator"): + raise SimulatorMissingError("Missing a simulator to run PriorPC.") + if self.data.simulator_dimensions == 1: self.plot_image = False diff --git a/src/deepdiagnostics/utils/register.py b/src/deepdiagnostics/utils/simulator_utils.py similarity index 96% rename from src/deepdiagnostics/utils/register.py rename to src/deepdiagnostics/utils/simulator_utils.py index 3021af9..ea8a3bf 100644 --- a/src/deepdiagnostics/utils/register.py +++ b/src/deepdiagnostics/utils/simulator_utils.py @@ -71,4 +71,8 @@ def load_simulator(name, simulator_kwargs): "Simulator improperly formed - requires a simulate method." ) - return simulator_instance \ No newline at end of file + return simulator_instance + + +class SimulatorMissingError(Exception): + pass \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 672c9d1..736cf13 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ from deepdiagnostics.data.simulator import Simulator from deepdiagnostics.models import SBIModel from deepdiagnostics.utils.config import get_item -from deepdiagnostics.utils.register import register_simulator +from deepdiagnostics.utils.simulator_utils import register_simulator class MockSimulator(Simulator): diff --git a/tests/test_client.py b/tests/test_client.py index d628afa..78a955d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -62,3 +62,25 @@ def test_main_missing_args(model_path): process = subprocess.run(command) exit_code = process.returncode assert exit_code == 1 + + +def test_missing_simulator(model_path, data_path): + command = [ + "diagnose", + "--model_path", + model_path, + "--data_path", + data_path, + "--simulator", + "Not_A_Registered_Name", + "--plots", + "PPC", + "--metrics", + "" + ] + process = subprocess.run(command, capture_output=True) + exit_code = process.returncode + stdout = process.stdout.decode("utf-8") + assert exit_code == 0 + plot_name = "PPC" + assert f"Cannot run {plot_name} - simulator missing." in stdout