diff --git a/chirho/contrib/__init__.py b/chirho/contrib/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/chirho/contrib/compexp/__init__.py b/chirho/contrib/compexp/__init__.py new file mode 100644 index 000000000..11237f5e2 --- /dev/null +++ b/chirho/contrib/compexp/__init__.py @@ -0,0 +1,8 @@ +from .composeable_expectation.composed_expectation import ComposedExpectation +from .composeable_expectation.expectation_atom import ExpectationAtom +from .handlers.expectation_handler import ExpectationHandler +from .handlers.importance_sampling_expectation_handler import ImportanceSamplingExpectationHandler +from .handlers.montecarlo_expectation_handler import MonteCarloExpectationHandler +from .typedecs import StochasticFunction, ExpectationFunction +from .composeable_expectation import grad +E = ExpectationAtom diff --git a/chirho/contrib/compexp/composeable_expectation/__init__.py b/chirho/contrib/compexp/composeable_expectation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/chirho/contrib/compexp/composeable_expectation/composed_expectation.py b/chirho/contrib/compexp/composeable_expectation/composed_expectation.py new file mode 100644 index 000000000..726b8a553 --- /dev/null +++ b/chirho/contrib/compexp/composeable_expectation/composed_expectation.py @@ -0,0 +1,55 @@ +from typing import Callable, List, TYPE_CHECKING +from ..typedecs import ModelType +import torch +from torch import Tensor as TT +if TYPE_CHECKING: + from .expectation_atom import ExpectationAtom + + +class ComposedExpectation: + def __init__( + self, + children: List["ComposedExpectation"], + op: Callable[[TT, ...], TT], + parts: List["ExpectationAtom"] + ): + self.op = op + self.children = children + self.parents: List[ComposedExpectation] = [] + self.parts = parts + + self._normalization_constant_cancels = False + + def recursively_refresh_parts(self): + self.parts = None + + for child in self.children: + child.recursively_refresh_parts() + + self.parts = [] + for child in self.children: + self.parts.extend(child.parts) + + def __op_other(self, other: "ComposedExpectation", op) -> "ComposedExpectation": + ce = ComposedExpectation( + children=[self, other], + op=op, + parts=self.parts + other.parts) + self.parents.append(ce) + other.parents.append(ce) + return ce + + def __truediv__(self, other: "ComposedExpectation") -> "ComposedExpectation": + return self.__op_other(other, torch.divide) + + def __add__(self, other: "ComposedExpectation") -> "ComposedExpectation": + return self.__op_other(other, torch.add) + + def __mul__(self, other: "ComposedExpectation") -> "ComposedExpectation": + return self.__op_other(other, torch.multiply) + + def __sub__(self, other: "ComposedExpectation") -> "ComposedExpectation": + return self.__op_other(other, torch.subtract) + + def __call__(self, p: ModelType) -> TT: + return self.op(*[child(p) for child in self.children]) diff --git a/chirho/contrib/compexp/composeable_expectation/expectation_atom.py b/chirho/contrib/compexp/composeable_expectation/expectation_atom.py new file mode 100644 index 000000000..6cc8dc8e0 --- /dev/null +++ b/chirho/contrib/compexp/composeable_expectation/expectation_atom.py @@ -0,0 +1,93 @@ +import pyro +from typing import Optional +import torch +from torch import tensor as tt +from torch import Tensor as TT +from ..typedecs import ModelType, KWType +from .composed_expectation import ComposedExpectation +from ..typedecs import StochasticFunction +from ..ops import _compute_expectation_atom + + +class 
ExpectationAtom(ComposedExpectation):
+
+    def __init__(
+            self,
+            f: StochasticFunction,  # TODO say in docstring that this has to have scalar-valued output.
+            name: str,
+            log_fac_eps: float = 1e-45,
+            guide: Optional[ModelType] = None):
+
+        self.f = f
+        self.name = name
+        self.log_fac_eps = log_fac_eps
+        self.guide = guide
+
+        self._is_positive_everywhere = False
+
+        super().__init__(children=[], op=lambda v: v, parts=[self])
+
+    def recursively_refresh_parts(self):
+        self.parts = [self]
+
+    def __call__(self, p: ModelType) -> TT:
+        """
+        Overrides the non-atomic call to actually estimate the value of this expectation atom.
+        """
+        ret = _compute_expectation_atom(self, p)
+        if ret.ndim != 0:
+            raise ValueError(f"Argument f to {ExpectationAtom.__name__} with name {self.name} must return a scalar,"
+                             f" but got {ret} instead.")
+        return ret
+
+    def build_pseudo_density(self, p: ModelType) -> ModelType:
+
+        if not self._is_positive_everywhere:
+            raise NotImplementedError("Non-positive pseudo-density construction is not supported. "
+                                      f"Convert atom named {self.name} by using the output of "
+                                      "ComposedExpectation.split_into_positive_components().")
+
+        # This defines a new density that is the product of the density defined by p and this all-positive function
+        #  we want to take the expectation wrt.
+        def pseudo_density() -> KWType:
+            stochastics = p()
+            factor_name = self.name + "_factor"
+            pyro.factor(factor_name, torch.log(self.log_fac_eps + self.f(stochastics)))
+            return stochastics
+
+        return pseudo_density
+
+    # TODO maybe rename this to get_tabi_decomposition.
+    def split_into_positive_components(
+            self,
+            # TODO bdt18dosjk maybe don't allow for guide specification, but rather handler specification that
+            #  may specify a particular guide arrangement?
+            pos_guide: Optional[ModelType] = None,
+            neg_guide: Optional[ModelType] = None,
+            den_guide: Optional[ModelType] = None) -> "ComposedExpectation":
+
+        pos_part = ExpectationAtom(
+            f=lambda s: torch.relu(self.f(s)), name=self.name + "_split_pos", guide=pos_guide)
+        pos_part._is_positive_everywhere = True
+        neg_part = ExpectationAtom(
+            f=lambda s: torch.relu(-self.f(s)), name=self.name + "_split_neg", guide=neg_guide)
+        neg_part._is_positive_everywhere = True
+        den_part = ExpectationAtom(
+            lambda s: tt(1.), name=self.name + "_split_den", guide=den_guide)
+        den_part._is_positive_everywhere = True
+
+        ret: ComposedExpectation = (pos_part - neg_part) / den_part
+        ret._normalization_constant_cancels = True
+
+        return ret
+
+    def swap_self_for_other_child(self, other):
+        for parent in self.parents:
+            positions_as_child = [i for i, child in enumerate(parent.children) if child is self]
+            assert len(positions_as_child) >= 1, "This shouldn't be possible." \
+                                                 " There's a reference mismatch with parents."
+            # Now, swap out the old atom with the new composite.
+            for pac in positions_as_child:
+                parent.children[pac] = other
+                other.parents.append(parent)
+        self.parents = []
diff --git a/chirho/contrib/compexp/composeable_expectation/grad.py b/chirho/contrib/compexp/composeable_expectation/grad.py
new file mode 100644
index 000000000..8b5311adb
--- /dev/null
+++ b/chirho/contrib/compexp/composeable_expectation/grad.py
@@ -0,0 +1,113 @@
+from typing import Callable
+from ..typedecs import KWType
+import torch
+from torch import Tensor as TT
+import warnings
+from .expectation_atom import ExpectationAtom
+from .composed_expectation import ComposedExpectation
+
+
+def _build_df_dd(dparams: TT, di: int, part: ExpectationAtom) -> Callable[[KWType], TT]:
+    def df_dd(stochastics: KWType) -> TT:
+        y: TT = part.f(stochastics)
+
+        if y.ndim != 0:
+            raise ValueError(f"Argument f to {ExpectationAtom.__name__} with name {part.name} must return a scalar,"
+                             f" but got {y} instead.")
+
+        assert dparams[di].ndim == 0, "This shouldn't be possible due to the outer check of 1 dimension."
+
+        try:
+            df_ddparam, = torch.autograd.grad(
+                outputs=(y,),
+                # FIXME HACK Have to grad wrt the whole tensor apparently, and then index after.
+                inputs=(dparams,),
+                create_graph=True
+            )
+            df_ddparam = df_ddparam[di]
+        except RuntimeError as e:
+            if "does not require grad and does not have a grad_fn" in str(e):
+                # FIXME FIXME kf2801dgi1 this is only correct when this particular atom is a mul or div of one
+                #  that does require grad. It would be nice to not have to reproduce autodiff here, but somehow
+                #  this needs to know how the parent incorporates this atom. Maybe we could autodiff the
+                #  parent's op and see how this atom relates, then use that to determine what
+                #  should be returned here?
+                warnings.warn(f"The gradient of atom named {part.name} with respect to dparam {di}"
+                              f" is 0.0, but returning the original atom's value for now because"
+                              f" it's probably a scaling factor. This is a hack and should be fixed."
+                              f" See FIXME tagged kf2801dgi1.")
+                return y
+            else:
+                raise
+
+        assert df_ddparam.ndim == 0, "This shouldn't be possible due to out and in being 0 dimensional."
+
+        return df_ddparam
+
+    return df_dd
+
+
+# FIXME 7301ykd0sk See below. Want to convert this to a proper in-place operation with no return value.
+def gradify_in_place_but_need_return(
+        output: ComposedExpectation, dparams: TT, split_atoms=False) -> ComposedExpectation:
+
+    if dparams.ndim != 1:
+        raise ValueError(f"Argument dparams to {gradify_in_place_but_need_return.__name__} must be a 1d tensor, "
+                         f"but got ndim {dparams.ndim} instead.")
+
+    assert len(output.parts) >= 1, "This shouldn't be possible due to composites always having at least one " \
+                                   "part (themselves)."
+
+    if not len(dparams) >= 1:
+        raise ValueError(f"Argument dparams to {gradify_in_place_but_need_return.__name__} must have at least one "
+                         f"element, but got {len(dparams)} instead.")
+
+    # Only relevant if output is an atom. Just defining it outside of the loop so type checking is happy below.
+    sub_atom_composite = None
+
+    for part in output.parts:
+
+        sub_atoms = []
+
+        # Create a new gradient atom for each element of dparams.
+        for di, _ in enumerate(dparams):
+
+            # Create a new atom for just this element of the gradient vector.
+            ea = ExpectationAtom(
+                f=_build_df_dd(dparams, di, part),
+                name=f"d{part.name}_dd{di}",
+                log_fac_eps=part.log_fac_eps
+                # TODO maybe seed a new guide with the original guide (if present)?
+ ) + + if split_atoms: + ea = ea.split_into_positive_components() + + sub_atoms.append(ea) + + # Create a composite that simply concatenates the new atoms into one tensor. + sub_atom_composite = ComposedExpectation( + children=sub_atoms, + op=lambda *v: torch.stack(v, dim=0), + # Note bm72gdi1: This will be updated before the return of this function. + parts=[] + ) + + for parent in part.parents: + positions_as_child = [i for i, child in enumerate(parent.children) if child is part] + assert len(positions_as_child) >= 1, "This shouldn't be possible." \ + " There's a reference mismatch with parents." + # Now, swap out the old atom with the new composite. + for pac in positions_as_child: + parent.children[pac] = sub_atom_composite + sub_atom_composite.parents.append(parent) + + if isinstance(output, ExpectationAtom): + assert output.parts[0] and len(output.parts) == 1, "This shouldn't be possible: atom code broken?" + # FIXME 7301ykd0sk this is why you have to take the return value here... + output = sub_atom_composite + + # Note bm72gdi1 this ensures that the cached part list is up-to-date. + output.recursively_refresh_parts() + + return output diff --git a/chirho/contrib/compexp/handlers/__init__.py b/chirho/contrib/compexp/handlers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/chirho/contrib/compexp/handlers/expectation_handler.py b/chirho/contrib/compexp/handlers/expectation_handler.py new file mode 100644 index 000000000..2737ce52b --- /dev/null +++ b/chirho/contrib/compexp/handlers/expectation_handler.py @@ -0,0 +1,12 @@ +import pyro + + +class ExpectationHandler(pyro.poutine.messenger.Messenger): + + # noinspection PyMethodMayBeStatic + def _pyro_compute_expectation_atom(self, msg) -> None: + if msg["done"]: + # TODO bdt18dosjk Do something similar to the demo1 setup where we define an OOP interface as well. + # Then handler can be passed to expectation atoms so that users can specify different handlers + # for different atoms. + raise RuntimeError("Only one default expectation handler can be in effect at a time.") diff --git a/chirho/contrib/compexp/handlers/guide_registration_mixin.py b/chirho/contrib/compexp/handlers/guide_registration_mixin.py new file mode 100644 index 000000000..18fd7b62d --- /dev/null +++ b/chirho/contrib/compexp/handlers/guide_registration_mixin.py @@ -0,0 +1,173 @@ +import pyro +from typing import Callable, List, Tuple, Type, Optional +import torch +from pyro.infer.autoguide import AutoGuide +from ..typedecs import ModelType, KWType +from ..composeable_expectation.composed_expectation import ComposedExpectation +from ..composeable_expectation.expectation_atom import ExpectationAtom + + +class _GuideRegistrationMixin: + def __init__(self): + self.pseudo_densities = dict() + self.guides = dict() + self.registered_model = None + + def optimize_guides(self, lr: float, n_steps: int, + adjust_grads_: Callable[[torch.nn.Parameter, ...], None] = None, + callback: Optional[Callable[[str, int], None]] = None, + start_scale=1.0): + if not len(self.keys()): + raise ValueError("No guides registered. 
Did you call " + f"{_GuideRegistrationMixin.__name__}.register_guides?") + + for k in self.keys(): + pseudo_density = self.pseudo_densities[k] + guide = self.guides[k] + + current_scale = start_scale + + def scaling_pseudo_density() -> KWType: + return pyro.poutine.scale(pseudo_density, scale=current_scale)() + + # noinspection PyTypeChecker + elbo = pyro.infer.Trace_ELBO()(scaling_pseudo_density, guide) + elbo() # Call to surface parameters for optimizer. + optim = torch.optim.ASGD(elbo.parameters(), lr=lr) + + for i in range(n_steps): + + # Scale from start_scale up to 1. + current_scale = start_scale + (1. - start_scale) * (i / n_steps) + + for param in elbo.parameters(): + param.grad = None + optim.zero_grad() + + loss = elbo() + loss.backward() + + if adjust_grads_ is not None: + adjust_grads_(*tuple(elbo.parameters())) + + if callback is not None: + callback(k, i) + + optim.step() + + def register_guides(self, ce: ComposedExpectation, model: ModelType, + auto_guide: Optional[Type[AutoGuide]], auto_guide_kwargs=None, + allow_repeated_names=False): + self.clear_guides() + + if auto_guide_kwargs is None: + auto_guide_kwargs = dict() + else: + if auto_guide is None: + raise ValueError("auto_guide_kwargs provided but no auto_guide class provided. Did you mean to " + "provide an auto_guide class?") + + for part in ce.parts: + pseudo_density = part.build_pseudo_density(model) + if part.guide is not None: + guide = part.guide + + if tuple(guide().keys()) != tuple(model().keys()): + raise ValueError("A preset guide must return the same variables as the model, but got " + f"{tuple(guide().keys())} and {tuple(model().keys())} instead.") + + else: + if auto_guide is None: + raise ValueError("No guide preregistered and no no auto guide class provided.") + guide = auto_guide(model, **auto_guide_kwargs) + + if not allow_repeated_names: + if part.name in self.pseudo_densities: + raise ValueError(f"Repeated part name {part.name}.") + if part.name in self.guides: + raise ValueError(f"Repeated part name {part.name}.") + + if part.name not in self.pseudo_densities: + self.pseudo_densities[part.name] = pseudo_density + + if part.name not in self.guides: + self.guides[part.name] = guide + + self.registered_model = model + + def keys(self) -> frozenset: + assert tuple(self.pseudo_densities.keys()) == tuple(self.guides.keys()), "Should not be possible b/c these" \ + " are added to at the same time." + return frozenset(self.pseudo_densities.keys()) + + def clear_guides(self): + self.pseudo_densities = dict() + self.guides = dict() + self.registered_model = None + + def _get_pq(self, ea: "ExpectationAtom", p: ModelType) -> Tuple[ModelType, ModelType]: + try: + q: ModelType = self.guides[ea.name] + except KeyError: + raise KeyError(f"No guide registered for {ea.name}. " + f"Did you call {_GuideRegistrationMixin.__name__}.register_guides?") + + if p is not self.registered_model: + raise ValueError("The probability distribution registered with the guides does not match the " + "probability distribution called to compute the expectation. In other words," + f"the same p must be used in {_GuideRegistrationMixin.__name__}" + f".register_guides and {ComposedExpectation.__name__}.__call__.") + + return p, q + + def plot_guide_pseudo_likelihood( + self, rv_name: str, guide_kde_kwargs, pseudo_density_plot_kwargs, keys: List[str] = None): + # TODO move this to a separate class and inherit or something, just so plotting code doesn't clutter + # up functional code. 
+ import seaborn as sns + import matplotlib.pyplot as plt + + if not len(self.keys()): + raise ValueError("No guides registered. Did you call " + f"{_GuideRegistrationMixin.__name__}.register_guides?") + + figs = [] + + if keys is None: + keys = self.keys() + + for k in keys: + pseudo_density = self.pseudo_densities[k] + guide = self.guides[k] + + if self.registered_model()[rv_name].ndim != 0: + raise ValueError("Can only plot pseudo likelihood/guide comparisons for univariates.") + + fig, ax = plt.subplots(1, 1) + sns.kdeplot([guide()[rv_name].item() for _ in range(1000)], label="guide", **guide_kde_kwargs) + + tax = ax.twinx() + + model_samples = torch.tensor([self.registered_model()[rv_name] for _ in range(1000)]) + xx = torch.linspace(model_samples.min(), model_samples.max(), 1000).detach() + + lps = [] + for x in xx: + cm = pyro.poutine.condition(pseudo_density, data={rv_name: x}) + lp = pyro.poutine.trace(cm).get_trace().log_prob_sum() + lps.append(lp) + lps = torch.tensor(lps).exp().detach().numpy() + + # This will be squiggly if there are other latents. TODO smooth? + tax.plot(xx, lps, label="pseudo-density", **pseudo_density_plot_kwargs) + + ax.set_title(f"q and pseudo-p for {rv_name} and part {k}") + + # Add single legend for both. + lines, labels = ax.get_legend_handles_labels() + lines2, labels2 = tax.get_legend_handles_labels() + ax.legend(lines + lines2, labels + labels2, loc=0) + + figs.append(fig) + + return figs diff --git a/chirho/contrib/compexp/handlers/importance_sampling_expectation_handler.py b/chirho/contrib/compexp/handlers/importance_sampling_expectation_handler.py new file mode 100644 index 000000000..b7114e5fb --- /dev/null +++ b/chirho/contrib/compexp/handlers/importance_sampling_expectation_handler.py @@ -0,0 +1,35 @@ +from ..composeable_expectation.expectation_atom import ExpectationAtom +from .expectation_handler import ExpectationHandler +from .guide_registration_mixin import _GuideRegistrationMixin +import pyro +import torch +from ..utils import msg_args_kwargs_to_kwargs, kft +from ..typedecs import ModelType + + +class ImportanceSamplingExpectationHandler(ExpectationHandler, _GuideRegistrationMixin): + + def __init__(self, num_samples: int): + super().__init__() + self.num_samples = num_samples + + def _pyro__compute_expectation_atom(self, msg) -> None: + super()._pyro_compute_expectation_atom(msg) + + kwargs = msg_args_kwargs_to_kwargs(msg) + + ea: ExpectationAtom = kwargs.pop("ea") + p: ModelType = kwargs["p"] + p, q = self._get_pq(ea, p) + + fpqvals = [] + + for _ in range(self.num_samples): + qtr = pyro.poutine.trace(q).get_trace() + ptr = pyro.poutine.trace(pyro.poutine.replay(p, trace=qtr)).get_trace() + s = kft(qtr) + + fpqval = ptr.log_prob_sum() + torch.log(ea.log_fac_eps + ea.f(s)) - qtr.log_prob_sum() + fpqvals.append(torch.exp(fpqval)) + + msg["value"] = torch.mean(torch.stack(fpqvals), dim=0) diff --git a/chirho/contrib/compexp/handlers/montecarlo_expectation_handler.py b/chirho/contrib/compexp/handlers/montecarlo_expectation_handler.py new file mode 100644 index 000000000..47b7ff535 --- /dev/null +++ b/chirho/contrib/compexp/handlers/montecarlo_expectation_handler.py @@ -0,0 +1,25 @@ +from .expectation_handler import ExpectationHandler +import torch +from ..composeable_expectation.expectation_atom import ExpectationAtom +from ..typedecs import ModelType +from ..utils import msg_args_kwargs_to_kwargs + + +class MonteCarloExpectationHandler(ExpectationHandler): + + def __init__(self, num_samples: int): + self.num_samples = num_samples + + def 
_pyro__compute_expectation_atom(self, msg) -> None: + super()._pyro_compute_expectation_atom(msg) + + kwargs = msg_args_kwargs_to_kwargs(msg) + ea: ExpectationAtom = kwargs.pop("ea") + p: ModelType = kwargs["p"] + + fvals = [] + for _ in range(self.num_samples): + s = p() + fvals.append(ea.f(s)) + + msg["value"] = torch.mean(torch.stack(fvals), dim=0) diff --git a/chirho/contrib/compexp/ops.py b/chirho/contrib/compexp/ops.py new file mode 100644 index 000000000..95010cedb --- /dev/null +++ b/chirho/contrib/compexp/ops.py @@ -0,0 +1,11 @@ +import pyro +from .handlers.expectation_handler import ExpectationHandler +from .typedecs import ModelType +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .composeable_expectation.expectation_atom import ExpectationAtom + + +@pyro.poutine.runtime.effectful(type="_compute_expectation_atom") +def _compute_expectation_atom(ea: "ExpectationAtom", p: ModelType): + raise NotImplementedError(f"Must be called in the context of an {ExpectationHandler.__name__}.") diff --git a/chirho/contrib/compexp/typedecs.py b/chirho/contrib/compexp/typedecs.py new file mode 100644 index 000000000..0fa6245b4 --- /dev/null +++ b/chirho/contrib/compexp/typedecs.py @@ -0,0 +1,11 @@ +from typing import Callable +from torch import Tensor +from torch.nn import Parameter +from collections import OrderedDict + +KWType = OrderedDict[str, Tensor] +KWTypeNNParams = OrderedDict[str, Parameter] +ModelType = Callable[[], KWType] + +StochasticFunction = Callable[[KWType], Tensor] +ExpectationFunction = Callable[[ModelType], Tensor] diff --git a/chirho/contrib/compexp/utils.py b/chirho/contrib/compexp/utils.py new file mode 100644 index 000000000..6bab1b1ac --- /dev/null +++ b/chirho/contrib/compexp/utils.py @@ -0,0 +1,41 @@ +import torch +from torch import tensor as tt, Tensor as TT +from collections import OrderedDict +import inspect +from .typedecs import KWType, KWTypeNNParams + + +def flatten_dparams(dparams: KWTypeNNParams) -> TT: + return torch.cat([torch.flatten(dparams[k]) for k in dparams.keys()]) + + +def unflatten_df_dparams(cat_df_dparams: TT, dparams: KWTypeNNParams) -> KWType: + df_dparams = OrderedDict() + + last_fidx = 0 + for k in dparams.keys(): + slice_len = dparams[k].numel() + cat_df_dparams_slice = cat_df_dparams[last_fidx:last_fidx + slice_len] + last_fidx += slice_len + + df_dparams[k] = torch.unflatten(cat_df_dparams_slice, dim=0, sizes=dparams[k].shape) + + return df_dparams + + +def msg_args_kwargs_to_kwargs(msg): + + ba = inspect.signature(msg["fn"]).bind(*msg["args"], **msg["kwargs"]) + ba.apply_defaults() + + return ba.arguments + + +# Keyword arguments From Trace +def kft(trace) -> KWType: + # Copy the ordereddict. + new_trace = OrderedDict(trace.nodes.items()) + # Remove the _INPUT and _RETURN nodes. + del new_trace["_INPUT"] + del new_trace["_RETURN"] + return OrderedDict(zip(new_trace.keys(), [v["value"] for v in new_trace.values()])) diff --git a/docs/source/expectation_programming/README.md b/docs/source/expectation_programming/README.md new file mode 100644 index 000000000..dc35cd6ad --- /dev/null +++ b/docs/source/expectation_programming/README.md @@ -0,0 +1,2 @@ +The demo scripts in this folder need to be converted and reduced into some notebooks. +It contains two still-functional old demos that don't use the compositional expectation machinery. 
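+
+As a rough orientation while that conversion is pending, here is a minimal sketch of the intended
+`compexp` usage (the toy model below is illustrative only, not one of the demos):
+
+```python
+import pyro
+import pyro.distributions as dist
+from collections import OrderedDict
+import chirho.contrib.compexp as compexp
+
+def model():
+    # Models are thunks returning an OrderedDict of named samples.
+    return OrderedDict(x=pyro.sample("x", dist.Normal(0., 1.)))
+
+# An expectation atom wraps a scalar-valued function of the model's samples.
+e = compexp.E(f=lambda s: s["x"] ** 2, name="x_squared")
+
+# Estimate E_p[x^2] with vanilla Monte Carlo.
+with compexp.MonteCarloExpectationHandler(num_samples=1000):
+    print(e(model))
+```
+
+Atoms compose via `+`, `-`, `*`, and `/`; combining `split_into_positive_components` with
+`ImportanceSamplingExpectationHandler` gives the TABI estimator used in the demos.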
diff --git a/docs/source/expectation_programming/demo_compexp.py b/docs/source/expectation_programming/demo_compexp.py new file mode 100644 index 000000000..615ffa2e1 --- /dev/null +++ b/docs/source/expectation_programming/demo_compexp.py @@ -0,0 +1,336 @@ +# TODO Needs supporting stuff ported here. Also break out into notebook and tests. + + +import pyro +import torch +from torch import tensor as tt +from pyro.infer.autoguide import AutoGuide + +from toy_tabi_problem import ( + cost, + model as model_ttp, + q_optimal_normal_guide_mean_var, + MODEL_DIST as MODEL_TTP_DIST + ) +import pyro.distributions as dist +from collections import OrderedDict +import matplotlib.pyplot as plt +import numpy as np +import chirho.contrib.compexp as compexp + +import old_ep_demo_scratch as stor + +pyro.settings.set(module_local_params=True) + + +def main(): + # noinspection PyPep8Naming + D = tt(0.5) + # noinspection PyPep8Naming + C = tt(1.0) + # noinspection PyPep8Naming + GT = -1.1337 + + # Make sure we match a ground truth value. + toy_opt_guidep = dist.Normal(*q_optimal_normal_guide_mean_var(d=D, c=C, z=False)) + toy_opt_guiden = dist.Normal(*q_optimal_normal_guide_mean_var(d=D, c=C, z=True)) + + # decomposed toy expectation + # noinspection PyPep8Naming + ttE: compexp.ComposedExpectation = compexp.E( + f=lambda s: 1e1*cost(d=tt(0.5), c=tt(1.), **s), name='ttE' + ).split_into_positive_components( + # TODO bdt18dosjk + pos_guide=lambda: OrderedDict(x=pyro.sample('x', toy_opt_guidep)), + neg_guide=lambda: OrderedDict(x=pyro.sample('x', toy_opt_guiden)), + den_guide=lambda: OrderedDict(x=pyro.sample('x', MODEL_TTP_DIST)) + ) + + # Just to make sure everything still works after this is called. + ttE.recursively_refresh_parts() + + # + mcestimates = [] + for _ in range(1000): + with compexp.MonteCarloExpectationHandler(num_samples=100): + mcestimates.append(ttE(model_ttp)) + + plt.suptitle("MC Estimate") + plt.hist(mcestimates, bins=30) + mcestimate = torch.mean(torch.stack(mcestimates)) + plt.axvline(x=mcestimate, color='r') + plt.axvline(x=GT, color='black', linestyle='--') + plt.suptitle(f"MC Estimate: {mcestimate}\n GT = {GT}") + plt.show() + # + + # + # When using the decomposition above with per-atom guides and importance sampling, we get TABI. Because + # we've preset the optimal guides above, we will get an exact estimate. 
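+    # (Concretely: the split above rewrites E_p[f] as (E_p[relu(f)] - E_p[relu(-f)]) / E_p[1], and each
+    #  atom is estimated by importance sampling against its own guide, i.e. as the sample mean of
+    #  f(x) * p(x) / q(x) for x ~ q. With the optimal guides preset here, one sample per atom is exact.)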
+ iseh = compexp.ImportanceSamplingExpectationHandler(num_samples=1) + iseh.register_guides(ce=ttE, model=model_ttp, auto_guide=None, auto_guide_kwargs=None) + with iseh: + tabiestimate = ttE(model_ttp) + print(f"TABI Estimate: {tabiestimate}", f"GT = {GT}") + assert torch.isclose(tabiestimate, tt(GT), atol=1e-4) + # + + # + + ttE = compexp.E( + f=lambda s: 1e1 * cost(d=tt(0.5), c=tt(1.), **s), name='ttE' + ).split_into_positive_components() + + iseh2 = compexp.ImportanceSamplingExpectationHandler(num_samples=300) + iseh2.register_guides( + ce=ttE, + model=model_ttp, + auto_guide=pyro.infer.autoguide.AutoNormal, + auto_guide_kwargs=dict(init_scale=2.)) + + def plot_callback_(k, i): + if i % 5000 == 0: + figs = iseh2.plot_guide_pseudo_likelihood( + rv_name='x', + guide_kde_kwargs=dict(bw_method=0.1, color='orange'), + pseudo_density_plot_kwargs=dict(color='purple'), + keys=[k] if k is not None else None + ) + plt.show() + for f in figs: + plt.close(f) + + iseh2.optimize_guides( + lr=5e-3, n_steps=10001, + callback=plot_callback_ + ) + + tabi_learned_estimates = [] + for _ in range(100): + with iseh2: + tabi_learned_estimates.append(ttE(model_ttp)) + tabi_learned_estimates = torch.stack(tabi_learned_estimates).detach().numpy() + + plt.suptitle("TABI Gradient Estimate") + plt.hist(tabi_learned_estimates, bins=30) + plt.axvline(x=np.mean(tabi_learned_estimates), color='r') + plt.axvline(x=GT, color='black', linestyle='--') + plt.suptitle(f"TABI Gradient Estimate: {np.mean(tabi_learned_estimates)}\n GT = {GT}") + plt.show() + + # + + # + dps = torch.nn.Parameter(tt([0.5])) + + # noinspection PyPep8Naming + ttgradE1: compexp.ComposedExpectation = compexp.E( + # f=lambda s: 1e1*cost(d=dps[0] * dps[1], c=tt(1.), **s), name='ttgradE' + f=lambda s: 1e1*cost(d=dps[0], c=tt(1.), **s), name='ttgradE' + ).split_into_positive_components() + ttgradE1 = compexp.grad.gradify_in_place_but_need_return(ttgradE1, dparams=dps) + + mc_grad_estimates = [] + for _ in range(1000): + with compexp.MonteCarloExpectationHandler(num_samples=100): + mc_grad_estimates.append(ttgradE1(model_ttp)) + mc_grad_estimates = torch.stack(mc_grad_estimates).detach().numpy().T[0] + + plt.suptitle("MC Gradient Estimate") + plt.hist(mc_grad_estimates, bins=30) + plt.axvline(x=np.mean(mc_grad_estimates), color='r') + plt.axvline(x=2.358, color='black', linestyle='--') + plt.suptitle(f"MC Gradient Estimate: {np.mean(mc_grad_estimates)}\n GT = {2.358}") + plt.show() + # + + # + dps = torch.nn.Parameter(tt([0.25, 2.])) + + # noinspection PyPep8Naming + ttgradE2: compexp.ComposedExpectation = compexp.E( + f=lambda s: 1e1*cost(d=dps[0] * dps[1], c=tt(1.), **s), name='ttgradE' + ).split_into_positive_components() + ttgradE1 = compexp.grad.gradify_in_place_but_need_return(ttgradE2, dparams=dps) + + with compexp.MonteCarloExpectationHandler(num_samples=10): + print(ttgradE2(model_ttp), "GT unknown but it runs") + # + + # + dps = torch.nn.Parameter(tt([0.5])) + + ttgrad_tabi_unfit = compexp.E( + f=lambda s: 1e1*cost(d=dps[0], c=tt(1.), **s), name='ttgradE' + ) + ttgrad_tabi_unfit = compexp.grad.gradify_in_place_but_need_return(ttgrad_tabi_unfit, dparams=dps, split_atoms=True) + + iseh2 = compexp.ImportanceSamplingExpectationHandler(num_samples=50) + iseh2.register_guides( + ce=ttgrad_tabi_unfit, + model=model_ttp, + auto_guide=pyro.infer.autoguide.AutoNormal, + auto_guide_kwargs=dict(init_scale=1.5)) + iseh2.plot_guide_pseudo_likelihood( + rv_name='x', + guide_kde_kwargs=dict(bw_method=0.1, color='orange'), + 
pseudo_density_plot_kwargs=dict(color='purple') + ) + plt.show() + + tabi_unlearned_grad_estimates = [] + for _ in range(1000): + with iseh2: + tabi_unlearned_grad_estimates.append(ttgrad_tabi_unfit(model_ttp)) + tabi_unlearned_grad_estimates = torch.stack(tabi_unlearned_grad_estimates).detach().numpy() + + plt.suptitle("TABI Gradient Estimate") + plt.hist(tabi_unlearned_grad_estimates, bins=30) + plt.axvline(x=np.mean(tabi_unlearned_grad_estimates), color='r') + plt.axvline(x=2.358, color='black', linestyle='--') + plt.suptitle(f"TABI Gradient Estimate: {np.mean(tabi_unlearned_grad_estimates)}\n GT = {2.358}") + plt.show() + # + + # + dps = torch.nn.Parameter(tt([0.5])) + + ttgrad_tabi_unfit = compexp.E( + f=lambda s: 1e1 * cost(d=dps[0], c=tt(1.), **s), name='ttgradE' + ) + ttgrad_tabi_unfit = compexp.grad.gradify_in_place_but_need_return(ttgrad_tabi_unfit, dparams=dps, split_atoms=True) + + iseh2 = compexp.ImportanceSamplingExpectationHandler(num_samples=300) + iseh2.register_guides( + ce=ttgrad_tabi_unfit, + model=model_ttp, + auto_guide=pyro.infer.autoguide.AutoNormal, + auto_guide_kwargs=dict(init_scale=1.)) + + # Make guides that are roughly in the correct positions. + pos_iloc_grad, pos_istd_grad = q_optimal_normal_guide_mean_var(d=tt(0.5), c=tt(1.0), z=False) + neg_iloc_grad, neg_istd_grad = q_optimal_normal_guide_mean_var(d=tt(0.5), c=tt(1.0), z=True) + pos_guide_grad = stor.MultiModalGuide1D( + num_components=2, + init_loc=[pos_iloc_grad.item(), neg_iloc_grad.item()], + init_scale=[pos_istd_grad.item(), neg_istd_grad.item()] + ) + neg_guide_grad = stor.MultiModalGuide1D( + num_components=2, + init_loc=[pos_iloc_grad.item(), neg_iloc_grad.item()], + init_scale=[pos_istd_grad.item(), neg_istd_grad.item()] + ) + + # TODO is this a good pattern for specifying guides? If so should add some more runtime checking to make sure + # e.g. names exist and that the guides spit out the right stochastics. + iseh2.guides['dttgradE_dd0_split_pos'] = pos_guide_grad + iseh2.guides['dttgradE_dd0_split_neg'] = neg_guide_grad + # Leave 'dttgradE_dd0_split_den' as default AutoNormal. + + iseh2.plot_guide_pseudo_likelihood( + rv_name='x', + guide_kde_kwargs=dict(bw_method=0.1, color='orange'), + pseudo_density_plot_kwargs=dict(color='purple') + ) + plt.show() + + tabi_unlearned_good_grad_estimates = [] + for _ in range(100): + with iseh2: + tabi_unlearned_good_grad_estimates.append(ttgrad_tabi_unfit(model_ttp)) + tabi_unlearned_good_grad_estimates = torch.stack(tabi_unlearned_good_grad_estimates).detach().numpy() + + plt.suptitle("TABI Unlearned Good Gradient Estimate") + plt.hist(tabi_unlearned_good_grad_estimates, bins=30) + plt.axvline(x=np.mean(tabi_unlearned_good_grad_estimates), color='r') + plt.axvline(x=2.358, color='black', linestyle='--') + plt.suptitle(f"TABI Unlearned Good Gradient Estimate: {np.mean(tabi_unlearned_good_grad_estimates)}\n GT = {2.358}") + plt.show() + # + + # + dps = torch.nn.Parameter(tt([0.5])) + + ttgrad_tabi_unfit = compexp.E( + f=lambda s: 1e1 * cost(d=dps[0], c=tt(1.), **s), name='ttgradE' + ) # .split_into_positive_components() + ttgrad_tabi_unfit = compexp.grad.gradify_in_place_but_need_return(ttgrad_tabi_unfit, dparams=dps, split_atoms=True) + + # # HACK any atom with the word "den" in it is actually the same atom. 
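+    # # (All of the "den" atoms target the same normalizing constant E_p[1], so in principle a single
+    # #  atom and guide could be shared among them, as the commented-out block below sketches.)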
+ # broadcasting_den = None + # for part in ttgrad_tabi_unfit.parts: + # if 'den' in part.name: + # if broadcasting_den is None: + # broadcasting_den = part + # else: + # part.swap_self_for_other_child(broadcasting_den) + # ttgrad_tabi_unfit.recursively_refresh_parts() + + iseh2 = compexp.ImportanceSamplingExpectationHandler(num_samples=50) + iseh2.register_guides( + ce=ttgrad_tabi_unfit, + model=model_ttp, + auto_guide=pyro.infer.autoguide.AutoNormal, + auto_guide_kwargs=dict(init_scale=2.)) + + # Start the guides roughly in the correct positions. + pos_iloc_grad, pos_istd_grad = q_optimal_normal_guide_mean_var(d=tt(0.5), c=tt(1.0), z=False) + neg_iloc_grad, neg_istd_grad = q_optimal_normal_guide_mean_var(d=tt(0.5), c=tt(1.0), z=True) + pos_guide_grad = stor.MultiModalGuide1D( + num_components=2, + init_loc=[pos_iloc_grad.item(), neg_iloc_grad.item()], + init_scale=[pos_istd_grad.item() * 2., neg_istd_grad.item() * 2.], + studentt=True + ) + neg_guide_grad = stor.MultiModalGuide1D( + num_components=2, + init_loc=[pos_iloc_grad.item(), neg_iloc_grad.item()], + init_scale=[pos_istd_grad.item() * 2., neg_istd_grad.item() * 2.], + studentt=True + ) + + # TODO is this a good pattern for specifying guides? If so should add some more runtime checking to make sure + # e.g. names exist and that the guides spit out the right stochastics. + iseh2.guides['dttgradE_dd0_split_pos'] = pos_guide_grad + iseh2.guides['dttgradE_dd0_split_neg'] = neg_guide_grad + # Leave 'dttgradE_dd0_split_den' as default AutoNormal. + + NSTEPS = 10000 + + def plot_callback_(k, i): + if i % NSTEPS == 0: + figs = iseh2.plot_guide_pseudo_likelihood( + rv_name='x', + guide_kde_kwargs=dict(bw_method=0.1, color='orange'), + pseudo_density_plot_kwargs=dict(color='purple'), + keys=[k] if k is not None else None + ) + plt.show() + for f in figs: + plt.close(f) + + # plot_callback_(None, 0) + + iseh2.optimize_guides( + lr=1e-3, n_steps=NSTEPS + 1, + adjust_grads_=stor.abort_guide_grads_, + callback=plot_callback_ + ) + + tabi_learned_grad_estimates = [] + for _ in range(100): + with iseh2: + tabi_learned_grad_estimates.append(ttgrad_tabi_unfit(model_ttp)) + tabi_learned_grad_estimates = torch.stack(tabi_learned_grad_estimates).detach().numpy() + + plt.suptitle("TABI Learned Gradient Estimate") + plt.hist(tabi_learned_grad_estimates, bins=30) + plt.axvline(x=np.mean(tabi_learned_grad_estimates), color='r') + plt.axvline(x=2.358, color='black', linestyle='--') + plt.suptitle(f"TABI Learned Gradient Estimate: {np.mean(tabi_learned_grad_estimates)}\n GT = {2.358}") + plt.show() + # + + +if __name__ == "__main__": + main() diff --git a/docs/source/expectation_programming/old_ep_demo_scratch.py b/docs/source/expectation_programming/old_ep_demo_scratch.py new file mode 100644 index 000000000..d21a8acd7 --- /dev/null +++ b/docs/source/expectation_programming/old_ep_demo_scratch.py @@ -0,0 +1,1191 @@ +import pyro +import torch +from pyro.infer.autoguide import AutoMultivariateNormal +import pyro.distributions as dist +import matplotlib.pyplot as plt +import seaborn as sns +import inspect +from torch import tensor as tt + +from collections import OrderedDict + +from typing import ( + Callable, + TypeVar, + Optional, + Dict, + List, + Generic, + Union +) + +# from typing_extensions import Unpack + +import numpy as np +import functools + +from pyro.poutine.replay_messenger import ReplayMessenger + +pyro.settings.set(module_local_params=True) + + +# FIXME Use the actual type from pyro. Gotta find ref. Just using this to force type checking. 
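+# (The reference being looked for is likely pyro.poutine.Trace; the stub below just mirrors the two
+#  members this file actually uses.)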
+class TraceType: + @property + def nodes(self) -> Dict[str, Dict[str, torch.Tensor]]: + raise NotImplementedError() + + def log_prob_sum(self) -> torch.Tensor: + raise NotImplementedError() + + +KWType = OrderedDict[str, torch.Tensor] +KWTypeNNParams = OrderedDict[str, torch.nn.Parameter] +KWTypeWithTrace = OrderedDict[str, Union[torch.Tensor, TraceType]] + +# Model should involv epyro primitives. +ModelType = Callable[[], KWType] + +# This should not involve pyro primitives +# This tuple return is primarily to support gradients of expectations. +ExpectigrandType = Callable[[KWType], KWType] + +# This function takes both decision parameters and stochastics, and returns a tuple of things. +# This tuple return is primarily to support gradients of expectations. +DecisionExpectigrandType = Callable[[KWTypeNNParams, KWType], KWType] + + +# noinspection PyUnusedLocal +@pyro.poutine.runtime.effectful(type="expectation") +def expectation( + f: ExpectigrandType, + p: ModelType) -> KWType: + raise NotImplementedError("Requires an expectation handler to function.") + + +@pyro.poutine.runtime.effectful(type="build_expectigrand_gradient") +def build_expectigrand_gradient( + dparams: KWTypeNNParams, + f: DecisionExpectigrandType +) -> ExpectigrandType: + + # Making this an effectful operation allows constraints to 1) add their gradients to the decision + # parameters and 2) add any auxiliary parameters to dparams that are required to represent themselves + # as an unconstrained problem. + + def df_dd(stochastics: KWType) -> KWType: + + y: KWType = f(dparams, stochastics) + + if len(y) != 1: + # TODO eventually support multiple outputs? Will probably want to do this + # once constraints get properly involved. + raise ValueError("Decision function must return a single tensor value.") + + ddparams = torch.autograd.grad( + outputs=list(y.values()), + inputs=list(dparams.values()), + create_graph=True) + + return OrderedDict(zip((_gradify_dparam_name(k) for k in dparams.keys()), ddparams)) + + return df_dd + + +def _gradify_dparam_name(name: str) -> str: + return f"df_d{name}" + + +# noinspection PyUnusedLocal +@pyro.poutine.runtime.effectful(type="optimize_decision") +def optimize_decision( + f: DecisionExpectigrandType, + p: ModelType, + terminal_condition: Callable[[KWTypeNNParams, int], bool], + # Note the difference between this signature and the one in optimize_proposal. TODO unify. + adjust_grads: Optional[Callable[[KWType], KWType]] = None, + callback: Optional[Callable[[], None]] = None +) -> KWType: + raise NotImplementedError("Requires both an expectation and optimizer handler.") + + +def msg_args_kwargs_to_kwargs(msg): + + ba = inspect.signature(msg["fn"]).bind(*msg["args"], **msg["kwargs"]) + ba.apply_defaults() + + return ba.arguments + + +class ExpectationHandler(pyro.poutine.messenger.Messenger): + + def _pyro_expectation(self, msg) -> None: + if msg["done"]: + raise ValueError("You may be operating in a context with more than one expectation handler. In these" + " cases, you must explicitly specify which expectation handler to use by using the OOP" + " style call with the desired handler (e.g. `ExpectationHandler(...).expectation(...)`).") + + def expectation(self, *args, **kwargs): + # Calling this method blocks all other expectation handlers and uses only this one. 
+ with self: + with pyro.poutine.messenger.block_messengers(lambda m: isinstance(m, ExpectationHandler) and m is not self): + return expectation(*args, **kwargs) + + def optimize_proposal(self, *args, **kwargs): + # Calling this method blocks all other expectation handlers and uses only this one. + with self: + with pyro.poutine.messenger.block_messengers(lambda m: isinstance(m, ExpectationHandler) and m is not self): + return optimize_proposal(*args, **kwargs) + + +# Keyword arguments From Trace +def kft(trace: TraceType) -> KWType: + # Copy the ordereddict. + new_trace = OrderedDict(trace.nodes.items()) + # Remove the _INPUT and _RETURN nodes. + del new_trace["_INPUT"] + del new_trace["_RETURN"] + return OrderedDict(zip(new_trace.keys(), [v["value"] for v in new_trace.values()])) + + +DT = TypeVar("DT") +DTr = TypeVar("DTr") + + +class OppableDictParent: # FIXME 28dj10dlk + pass + + +# TODO inherit from dict to just extend relevant methods, but couldn't figure out proper +# inheritance of generics. +class OppableDict(OppableDictParent, Generic[DT]): + """ + Helper class for executing operations on all the values of a dictionary. + Heavily typed to make sure everything is lining up. + """ + + def __init__(self, d: Optional[OrderedDict[str, DT]] = None): + if d is None: + d: OrderedDict[str, DT] = OrderedDict() + + self._d = d + + def __getitem__(self, item: str) -> DT: + return self._d[item] + + def __contains__(self, item: str) -> bool: + return item in self._d + + def __setitem__(self, key: str, value: DT) -> None: + self._d[key] = value + + def items(self): + return self._d.items() + + def op(self, f: Callable[[DT], DTr]) -> 'OppableDict[DTr]': + ret = OppableDict() + + for k, v in self.items(): + ret[k] = f(v) + + return ret + + def op_other( + self, + f: Callable[[DT, ...], DTr], + *others: 'OppableDict[DT]') -> 'OppableDict[DTr]': + ret = OppableDict() + + for k, v in self.items(): + ret[k] = f(v, *tuple(o[k] for o in others)) + + return ret + + @functools.singledispatchmethod + def __sub__(self, other): + raise NotImplementedError() + + @__sub__.register(Union[float, torch.Tensor]) + def _(self, other: Union[float, torch.Tensor]) -> "OppableDict": + return self.op(lambda v: v - other) + + # FIXME 28dj10dlk Once https://github.com/python/cpython/issues/86153 drops. + + # @__add__.register(OppableDict) # FIXME 28dj10dlk + @__sub__.register(OppableDictParent) + # def _(self, other: "OppableDict") -> "OppableDict": # FIXME 28dj10dlk desired + def _(self, other: OppableDictParent) -> "OppableDict": # FIXME 28dj10dlk + assert isinstance(other, OppableDict) # FIXME 28dj10dlk runtime check instead. 
+ return self.op_other(lambda v, o: v - o, other) + + @functools.singledispatchmethod + def __add__(self, other): + raise NotImplementedError() + + @__add__.register(Union[float, torch.Tensor]) + def _(self, other: Union[float, torch.Tensor]) -> "OppableDict": + return self.op(lambda v: v + other) + + # FIXME 28dj10dlk + @__add__.register(OppableDictParent) + def _(self, other: OppableDictParent) -> "OppableDict": + assert isinstance(other, OppableDict) + return self.op_other(lambda v, o: v + o, other) + + @functools.singledispatchmethod + def __truediv__(self, other): + raise NotImplementedError() + + @__truediv__.register(Union[float, torch.Tensor]) + def _(self, other: Union[float, torch.Tensor]) -> "OppableDict": + return self.op(lambda v: v / other) + + # @__truediv__.register(OppableDict) # FIXME 28dj10dlk + @__truediv__.register(OppableDictParent) + # def _(self, other: "OppableDict") -> "OppableDict": # FIXME 28dj10dlk desired + def _(self, other: OppableDictParent) -> "OppableDict": # FIXME 28dj10dlk + assert isinstance(other, OppableDict) # FIXME 28dj10dlk runtime check instead. + return self.op_other(lambda v, o: v / o, other) + + @functools.singledispatchmethod + def __mul__(self, other): + raise NotImplementedError() + + @__mul__.register(Union[float, torch.Tensor]) + def _(self, other: Union[float, torch.Tensor]) -> "OppableDict": + return self.op(lambda v: v * other) + + # FIXME 28dj10dlk desired + @__mul__.register(OppableDictParent) + def _(self, other: OppableDictParent) -> "OppableDict": + assert isinstance(other, OppableDict) + return self.op_other(lambda v, o: v * o, other) + + @property + def wrapped(self) -> OrderedDict[str, DT]: + return self._d + + +class DictOLists(OppableDict[List[torch.Tensor]]): + def append(self, value: Dict[str, torch.Tensor]) -> None: + + for k, v in value.items(): + + if k not in self: + self[k] = [] + + self[k].append(v) + + +class MonteCarloExpectation(ExpectationHandler): + # Adapted from Rafal's "query library" code. + + def __init__(self, num_samples: int): + self.num_samples = num_samples + + super().__init__() + + def _pyro_expectation(self, msg) -> None: + super()._pyro_expectation(msg) + + kwargs = msg_args_kwargs_to_kwargs(msg) + p: ModelType = kwargs["p"] + f: ExpectigrandType = kwargs["f"] + + fvs = DictOLists() + + for i in range(self.num_samples): + trace = pyro.poutine.trace(p).get_trace() + ret = trace.nodes["_RETURN"]["value"] # type: KWType + fv = f(ret) + fvs.append(fv) + + msg_value = fvs.op(lambda v: torch.sum(torch.tensor(v)) / self.num_samples) + + msg["value"] = msg_value + msg["done"] = True + + +class SNISExpectation(ExpectationHandler): + + def __init__(self, q: ModelType, num_samples: int): + self.q = q + self.num_samples = num_samples + + super().__init__() + + def _pyro_expectation(self, msg) -> None: + super()._pyro_expectation(msg) + + kwargs = msg_args_kwargs_to_kwargs(msg) + q = self.q + p = kwargs["p"] # type: ModelType + f = kwargs["f"] # type: ExpectigrandType + + fvs = DictOLists() + plps = [] + qlps = [] + + for i in range(self.num_samples): + # Sample stochastics from the proposal distribution. + q_trace = pyro.poutine.trace(q).get_trace() + qlp = q_trace.log_prob_sum() + + # Trace the full model with the proposed stochastics. + with ReplayMessenger(trace=q_trace): + trace = pyro.poutine.trace(lambda: f(p())).get_trace() + + # Record the return value and the log probability with respect to the model. 
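+            # (These feed the self-normalized importance sampling estimate below:
+            #  sum_i w_i * f(x_i), with weights w_i proportional to p(x_i) / q(x_i) and
+            #  normalized across the samples.)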
+ fv = trace.nodes["_RETURN"]["value"] # type: KWType + plp = trace.log_prob_sum() + + fvs.append(fv) + plps.append(plp) + qlps.append(qlp) + + plps = torch.tensor(plps) + qlps = torch.tensor(qlps) + + unw = plps - qlps # unnormalized weights + w = torch.exp(unw - torch.logsumexp(unw, dim=0)) # normalized weights + + msg_value = fvs.op(lambda v: torch.sum(torch.tensor(v) * w)) + + msg["value"] = msg_value + msg["done"] = True + + +# noinspection PyUnusedLocal +@pyro.poutine.runtime.effectful(type="optimize_proposal") +def optimize_proposal(p: ModelType, f: ExpectigrandType, n_steps=1, lr=0.01, + # Note the difference in signature here vs in the optimize_decision. TODO unify. + adjust_grads_: Optional[Callable[[torch.nn.Parameter, ...], None]] = None, + callback: Optional[Callable[[], None]] = None): + raise NotImplementedError() + + +class LazySVIStuff: + elbo: pyro.infer.Trace_ELBO = None + optim = None + + +class TABIExpectation(ExpectationHandler): + QTRACE_KEY = "TABIExpectation_q_trace_for_log_prob_and_replay" + FAC_KEY = "TABIExpectation_expectation_targeting_log_factor" + + def __init__(self, q_plus: ModelType, q_den: ModelType, num_samples: int, + grad_clip: Optional[Callable] = None, + q_minus: Optional[ModelType] = None): + self.q_plus = q_plus + self.q_minus = q_minus + self.q_den = q_den + self.num_samples = num_samples + self.grad_clip = grad_clip + + self._lazy_q_plus_svi = LazySVIStuff() + self._lazy_q_minus_svi = LazySVIStuff() + self._lazy_q_den_svi = LazySVIStuff() + + super().__init__() + + def __enter__(self): + self._lazy_q_plus_svi = LazySVIStuff() + self._lazy_q_minus_svi = LazySVIStuff() + self._lazy_q_den_svi = LazySVIStuff() + + return super().__enter__() + + def _optimize_proposal_part(self, n_steps, elbo, optim, adjust_grads_=None, + callback: Optional[Callable[[], None]] = None): + for step in range(0, n_steps + 1): + for param in elbo.parameters(): + param.grad = None + + optim.zero_grad() + loss = elbo() + loss.backward() + + if adjust_grads_ is not None: + adjust_grads_(*tuple(elbo.parameters())) + + if callback is not None: + callback() + + optim.step() + + def get_part(self, sign: float, f: ExpectigrandType, p: ModelType): + def factor_augmented_p(): + stochastics: KWType = p() + fv: KWType = f(stochastics) + fv_od = OppableDict(fv) + + # FIXME HACK 1e-6 is a hack to avoid log(0), make passable argument? 
+ facval = fv_od.op(lambda v: torch.log(1e-6 + torch.relu(sign * v))) + + for k, v in facval.items(): + aug_k = f'{self.FAC_KEY}_{sign}_{k}' + fac = pyro.factor(aug_k, v) + + assert aug_k not in stochastics + stochastics[aug_k] = fac + + return stochastics + return factor_augmented_p + + @staticmethod + def get_svi_stuff(p, q, svi_stuff, lr): + if svi_stuff.elbo is None: + svi_stuff.elbo = pyro.infer.Trace_ELBO()(p, q) + svi_stuff.elbo() + if svi_stuff.optim is None: + svi_stuff.optim = torch.optim.ASGD(svi_stuff.elbo.parameters(), lr=lr) + return svi_stuff + + def _pyro_optimize_proposal(self, msg) -> None: + kwargs = msg_args_kwargs_to_kwargs(msg) + p: ModelType = kwargs["p"] + f: ExpectigrandType = kwargs["f"] + n_steps: int = kwargs["n_steps"] + lr: float = kwargs["lr"] + adjust_grads_: Optional[Callable[[torch.nn.Parameter, ...], None]] = kwargs["adjust_grads_"] + callback: Optional[Callable[[], None]] = kwargs["callback"] + + self.get_svi_stuff(self.get_part(1., f, p), self.q_plus, self._lazy_q_plus_svi, lr) + self._optimize_proposal_part(n_steps, self._lazy_q_plus_svi.elbo, self._lazy_q_plus_svi.optim, + adjust_grads_, callback=callback) + + self.get_svi_stuff(p, self.q_den, self._lazy_q_den_svi, lr) + self._optimize_proposal_part(n_steps, self._lazy_q_den_svi.elbo, self._lazy_q_den_svi.optim, + adjust_grads_, callback=callback) + + if self.q_minus is not None: + self.get_svi_stuff(self.get_part(-1., f, p), self.q_minus, self._lazy_q_minus_svi, lr) + self._optimize_proposal_part(n_steps, self._lazy_q_minus_svi.elbo, self._lazy_q_minus_svi.optim, + adjust_grads_, callback=callback) + + msg["value"] = None + msg["done"] = True + + def _pyro_expectation(self, msg) -> None: + super()._pyro_expectation(msg) + + kwargs = msg_args_kwargs_to_kwargs(msg) + q_plus = self.q_plus + q_minus = self.q_minus + q_den = self.q_den + p = kwargs["p"] # type: ModelType + f = kwargs["f"] # type: ExpectigrandType + + # TODO All this funkiness is to reuse logic between the component proposals, but its super opaque + # and annoying. It's also set up to get the actual proposal trace object to where it needs to go + # so that importance weights can be computed inside the expectigrand. Would like to clean up/simplify. + # This would allow us to get rid of KWTypeWithTrace also. Also, now that I've added the typing to sort + # out the redirection mess, it's like equally as verbose as not sharing code. 
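+        # (Shape of the computation below: e_plus, e_minus, and e_den are Monte Carlo estimates over
+        #  the respective proposals, and the final TABI estimate is (e_plus - e_minus) / e_den.)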
+ + def expectigrand(lf_: Optional[OppableDict[torch.Tensor]], q_trace: TraceType) -> KWType: + with ReplayMessenger(trace=q_trace): + unnorm_log_p = pyro.poutine.trace(p).get_trace().log_prob_sum() + + unnorm_log_q = q_trace.log_prob_sum() + + ret: KWType + if lf_ is not None: + ret = lf_.op(lambda v: torch.exp(v + unnorm_log_p - unnorm_log_q)).wrapped + else: + ret = OrderedDict(expected_importance_weight=torch.exp(unnorm_log_p - unnorm_log_q)) + return ret + + def get_signed_lf(kwstochastics: KWType, s: float) -> OppableDict[torch.Tensor]: + fv = OppableDict(f(kwstochastics)) + return fv.op(lambda v: torch.log(torch.relu(s * v))) + + def expectigrand_plus(kwstochastics: KWTypeWithTrace) -> KWType: + q_trace: TraceType = kwstochastics.pop(self.QTRACE_KEY) + kwstochastics: KWType + return expectigrand(lf_=get_signed_lf(kwstochastics, 1.), q_trace=q_trace) + + def expectigrand_minus(kwstochastics: KWTypeWithTrace) -> KWType: + q_trace: TraceType = kwstochastics.pop(self.QTRACE_KEY) + kwstochastics: KWType + return expectigrand(lf_=get_signed_lf(kwstochastics, -1.), q_trace=q_trace) + + def expectigrand_den(kwstochastics: KWTypeWithTrace) -> KWType: + q_trace: TraceType = kwstochastics.pop(self.QTRACE_KEY) + kwstochastics: KWType + return expectigrand(lf_=None, q_trace=q_trace) + + def get_get_qkwstochastics(q: ModelType) -> ModelType: + def get_qkwstochastics() -> KWTypeWithTrace: + qtr: TraceType = pyro.poutine.trace(q).get_trace() + qkwstochastics = kft(qtr) + # Add the log prob of the proposal for use down the line. + assert self.QTRACE_KEY not in qkwstochastics + qkwstochastics: KWTypeWithTrace + qkwstochastics[self.QTRACE_KEY] = qtr + return qkwstochastics + + return get_qkwstochastics + + with pyro.poutine.messenger.block_messengers(lambda m: m is self): + + with MonteCarloExpectation(self.num_samples): + e_plus = expectation( + f=expectigrand_plus, + p=get_get_qkwstochastics(q_plus) + ) + + if q_minus is not None: + with MonteCarloExpectation(self.num_samples): + e_minus = expectation( + f=expectigrand_minus, + p=get_get_qkwstochastics(q_minus) + ) + else: + e_minus = 0.0 + + with MonteCarloExpectation(self.num_samples): + e_den: torch.Tensor = expectation( + f=expectigrand_den, + p=get_get_qkwstochastics(q_den) + )["expected_importance_weight"] + + # Set this up to "broadcast" between the named tensors of the return dictionaries. + e_plus_od = OppableDict(e_plus) + e_minus_od: Union[OppableDict[torch.Tensor], float] = \ + OppableDict(e_minus) if isinstance(e_minus, OrderedDict) else e_minus + + msg["value"] = ((e_plus_od - e_minus_od) / e_den).wrapped + msg["done"] = True + + +class ConstraintHandler(pyro.poutine.messenger.Messenger): + def __init__(self, g: DecisionExpectigrandType, tau: float, threshold: float): + """ + :param g: The function of decision parameters and stochastics that feeds into the constraint. + :param tau: The scaling factor used to add the constraint gradient to the decision parameter gradient. This + is required to convert constrained problems into unconstrained problems. + :param threshold: The threshold value for the constraint. + """ + + self.g = g + self.tau = tt(tau) + self.threshold = tt(threshold) + +# TODO Proposal Optimization for Constraints +# So I think these constraints need to have expectation handlers passed to them. They can add them to the stack +# in the enter/exit, and then use the OOP strategy to call a specific optimize_proposal? 
I guess the constraints +# should be able to preempt the optimize_proposal call and do their own optimization — they have everything they need +# cz they have the constraint expectigrand passed directly to them. So no oop for optimize_proposal, the constraints +# just preempt the call, run optimize_proposal in the context of whatever internal constraint expectation estimators +# they happen to be using — they just need to block other expectation handlers during that execution. +# Remember that the DecisionOptimizationHandler automatically converts any optimize_proposal calls to operate over the +# gradient, so proposal optimization here just needs to use the raw, non-differentiated function. + + +class MeanConstraintHandler(ConstraintHandler): + + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) + + def _pyro_post_build_expectigrand_gradient(self, msg) -> None: + kwargs = msg_args_kwargs_to_kwargs(msg) + dparams: KWTypeNNParams = kwargs["dparams"] + df_dd: ExpectigrandType = msg["value"] + + # Block all constraint handlers, because we just want the raw gradients here so we can add them. + with pyro.poutine.messenger.block_messengers(lambda m: isinstance(m, ConstraintHandler)): + dg_dd: ExpectigrandType = build_expectigrand_gradient(dparams, self.g) + + def rewrapped_df_dd(stochastics: KWType) -> KWType: + + # FIXME so the trick here is that these same stochastics aren't supposed to be sent into all + # three of these functions (df_dd, dg_dd, and g). Instead they're supposed to be three separate + # expectation operations. So what we need here is to actually make grad_expectation an + # operation, and then in post we need to call self.grad_expectation_handler.expectation and + # self.expectation_handler.expectation, and add everything up as needed. + + df_dd_ret = OppableDict(df_dd(stochastics)) + + cret = self.g(dparams, stochastics) + if len(cret) > 1: + raise ValueError(f"{self.__class__.__name__} only supports scalar constraints," + f" but got return values named {tuple(cret.keys())}") + + cval = tuple(cret.values())[0] + constraint_violated = cval > self.threshold + + if constraint_violated: + print("Constraint Violated!") + dg_dd_ret = OppableDict(dg_dd(stochastics)) + return (df_dd_ret + dg_dd_ret * self.tau).wrapped + else: + return df_dd_ret.wrapped + + msg["value"] = rewrapped_df_dd + + +class DecisionOptimizerHandler(pyro.poutine.messenger.Messenger): + + def __init__(self, dparams: KWTypeNNParams, lr: float, + proposal_update_steps: int = 0, proposal_update_lr: float = None, + proposal_adjust_grads_: Optional[Callable[[torch.nn.Parameter, ...], None]] = None): + self.dparams = dparams + self.lr = lr + self.proposal_update_steps = proposal_update_steps + self.proposal_update_lr = proposal_update_lr + self.proposal_adjust_grads_ = proposal_adjust_grads_ + + if self.proposal_update_steps > 0: + assert self.proposal_update_lr is not None, "Learning rate for proposal update must be specified if " \ + "proposal update steps is greater than 0." + + def _swap_f_for_df(self, msg) -> None: + kwargs = msg_args_kwargs_to_kwargs(msg) + f: DecisionExpectigrandType = kwargs["f"] + fp: ExpectigrandType = build_expectigrand_gradient(self.dparams, f) + kwargs["f"] = fp + msg["kwargs"] = kwargs + + def _pyro_optimize_proposal(self, msg) -> None: + # This handler replaces the function of interest with the gradient of that function with respect to the + # decision parameters. This supports any proposal optimization by giving it the correct target function. 
+        self._swap_f_for_df(msg)
+
+    def _swap_f_for_f_of_d(self, msg) -> None:
+        # This converts the decision function to a function just of stochastics so it can be used with
+        # standard expectation operations.
+        kwargs = msg_args_kwargs_to_kwargs(msg)
+        f: DecisionExpectigrandType = kwargs["f"]
+
+        # Only actually swap this out if it's required. This allows users to pass in the results of
+        # build_expectigrand_gradient, which already returns a non-decision expectigrand, without
+        # having to block this handler's effect.
+        if len(inspect.signature(f).parameters) == 1:
+            return
+
+        fp: ExpectigrandType = lambda stochastics: f(self.dparams, stochastics)
+        kwargs["f"] = fp
+        msg["kwargs"] = kwargs
+
+    def _pyro_expectation(self, msg) -> None:
+        # Expectation calls within this handler need to include the decision parameters as arguments,
+        # so this converts things to the default case not involving decision parameters.
+        self._swap_f_for_f_of_d(msg)
+
+    def _pyro_optimize_decision(self, msg) -> None:
+        kwargs = msg_args_kwargs_to_kwargs(msg)
+        f: DecisionExpectigrandType = kwargs["f"]
+        p: ModelType = kwargs["p"]
+        adjust_grads: Optional[Callable[[KWType], KWType]] = kwargs["adjust_grads"]
+        terminal_condition: Callable[[KWTypeNNParams, int], bool] = kwargs["terminal_condition"]
+        callback: Optional[Callable[[], None]] = kwargs["callback"]
+
+        optim = torch.optim.SGD(list(self.dparams.values()), lr=self.lr)
+
+        i = 0
+        while not terminal_condition(self.dparams, i):
+
+            optim.zero_grad()
+            # Not sure if this is necessary.
+            for d in self.dparams.values():
+                d.grad = None
+
+            # Block self here because the gradient function already handles the partial evaluation. That is,
+            # we don't want _swap_f_for_f_of_d called, because build_expectigrand_gradient already returns a valid
+            # ExpectigrandType.
+            with pyro.poutine.messenger.block_messengers(lambda m: m is self):
+                grad_estimate = expectation(f=build_expectigrand_gradient(self.dparams, f), p=p)
+
+            res = OppableDict(grad_estimate).op(lambda v: torch.isfinite(v))
+            if not torch.tensor(tuple(res.wrapped.values())).all():
+                print("Warning: Non-finite gradient estimate encountered. Skipping update.")  # TODO warnings module
+                continue
+
+            if adjust_grads is not None:
+                grad_estimate: KWType = adjust_grads(grad_estimate)
+
+            # Sanity check.
+            assert tuple(grad_estimate.keys()) == tuple(_gradify_dparam_name(k) for k in self.dparams.keys())
+
+            # Assign gradients to parameters.
+            for k, dp in self.dparams.items():
+                dp.grad = grad_estimate[_gradify_dparam_name(k)]
+
+            # And update those parameters with the specified optimizer.
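+            # [Editor's note: with the .grad fields assigned above, the step below performs, for each
+            # decision parameter d and gradient estimate g_d,
+            #
+            #     d <- d - lr * g_d
+            #
+            # i.e. plain stochastic gradient descent on the expected cost (torch.optim.SGD with its
+            # default momentum of 0).]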
+ optim.step() + + if self.proposal_update_steps > 0: + optimize_proposal(p=p, f=f, n_steps=self.proposal_update_steps, lr=self.proposal_update_lr, + adjust_grads_=self.proposal_adjust_grads_) + + if callback is not None: + callback() + + i += 1 + + # msg["value"] = tuple(d.item() for d in self.dparams) + msg["value"] = tuple(v.item() for v in self.dparams.values()) + msg["done"] = True + + +class MultiModalGuide1D(pyro.nn.PyroModule): + + @pyro.nn.PyroParam(constraint=dist.constraints.simplex) + def pi(self): + return torch.ones(self.num_components) / self.num_components + + @pyro.nn.PyroParam(constraint=dist.constraints.positive) + def scale(self): + return self.init_scale + + @pyro.nn.PyroParam(constraint=dist.constraints.real) + def loc(self): + return self.init_loc + + def __init__(self, *args, num_components: int, init_loc, init_scale, studentt=False, **kwargs): + super().__init__(*args, **kwargs) + + self.num_components = num_components + self.init_loc = torch.tensor(init_loc) + self.init_scale = torch.tensor(init_scale) + self.studentt = studentt + + if self.init_loc.shape != (self.num_components,): + raise ValueError("init_loc must be a tensor of shape (num_components,)") + if self.init_scale.shape != (self.num_components,): + raise ValueError("init_scale must be a tensor of shape (num_components,)") + + def forward(self): + component_idx = pyro.sample("component_idx", dist.Categorical(self.pi), infer={'is_auxiliary': True}) + scale = torch.clip(self.scale[component_idx], torch.tensor(0.05), torch.tensor(1.5)) + + if self.studentt: + x = pyro.sample("x", dist.StudentT(10, self.loc[component_idx], scale)) + else: + x = pyro.sample("x", dist.Normal(self.loc[component_idx], scale)) + return OrderedDict(x=x) + + def __call__(self, *args, **kwargs) -> KWType: + return super().__call__(*args, **kwargs) + + +def clip_decision_grads(grads: KWType): + # These blow up sometimes, so just clip them (they are 1d so no need to worry about norms). + ret_grads = OrderedDict() + for k in grads: + ret_grads[k] = torch.clip(grads[k], -2e-1, 2e-1) + return ret_grads + + +def abort_guide_grads_(*parameters: torch.nn.Parameter, lim=50.): + # These gradients also blow up, but clipping them causes weird non-convergence. Just aborting + # the gradient update seems to work. + if torch.any(torch.tensor([torch.any(param.grad > lim) for param in parameters])): + for param in parameters: + param.grad = torch.zeros_like(param.grad) + + +def main(): + + # See https://www.desmos.com/calculator/ixtzpb4l75 for analytical computation of the expectation. + # This all amounts to a functional on C and D, which can be set in the desmos graph. + C = tt(1.) + D = OrderedDict(d=torch.nn.Parameter(tt(0.5))) + NSnaive = 100000 + GT = -1.1337 + print(f"Ground Truth: {GT}") + + # Freeze this decision parameter, as we're using it as a non-optimizeable constant. + for d in D.values(): + d.requires_grad = False + + from toy_tabi_problem import ( + model as model_ttp, + cost as _cost_ttp, + q_optimal_normal_guide_mean_var, + MODEL_DIST as MODEL_TTP_DIST + ) + + # A scaled up cost function, just because the original one is small. 
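+    # [Editor's note: the first estimate below is a plain Monte Carlo baseline. Judging by the handler's
+    # name, MonteCarloExpectation presumably computes something like
+    #
+    #     E_p[f] ~= (1 / N) * sum_i f(x_i),    x_i ~ p,
+    #
+    # whose error shrinks only as O(1 / sqrt(N)); the SNIS and TABI estimators exercised afterwards are
+    # meant to beat this baseline at far smaller N.]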
+    def dparam_scost_ttp(dparams: KWTypeNNParams, stochastics: KWType, c: torch.Tensor) -> KWType:
+        return OrderedDict(cost=1e1 * _cost_ttp(**dparams, **stochastics, c=c))
+
+    def get_scost_ttp(dparams: KWTypeNNParams, c: torch.Tensor) -> Callable[[KWType], KWType]:
+        return lambda stochastics: dparam_scost_ttp(dparams, stochastics, c=c)
+
+    def get_dparam_scost_ttp(c: torch.Tensor) -> Callable[[KWTypeNNParams, KWType], KWType]:
+        return lambda dparams, stochastics: dparam_scost_ttp(dparams, stochastics, c=c)
+
+    with MonteCarloExpectation(num_samples=NSnaive):
+        print(f"MCE TABI Toy (N={NSnaive})",
+              expectation(f=get_scost_ttp(dparams=D, c=C), p=model_ttp)["cost"])
+
+    def subopt_guide() -> KWType:
+        return OrderedDict(x=pyro.sample('x', dist.Normal(0.5, 2.)))
+
+    with SNISExpectation(
+            q=subopt_guide,
+            num_samples=50000
+    ):
+        print("SNIS SubOptGuide TABI Toy (N=50000)",
+              expectation(f=get_scost_ttp(dparams=D, c=C), p=model_ttp)["cost"])
+
+    # Do SNIS again but with an optimized proposal. We use a student-t here because SNIS divides the
+    # original probability by the proposal probability, so if something is sampled in the tails of the proposal
+    # but not in the tails of the original, the ratio will blow up and cause problems. SNIS does this before it
+    # ever sees the cost function, in order to self-normalize the weights.
+    # noinspection PyUnresolvedReferences
+    snis_opt_guidep = dist.StudentT(1, *q_optimal_normal_guide_mean_var(**D, c=C, z=False))
+    # noinspection PyUnresolvedReferences
+    snis_opt_guiden = dist.StudentT(1, *q_optimal_normal_guide_mean_var(**D, c=C, z=True))
+
+    def opt_guide_mix() -> KWType:
+        # noinspection PyUnresolvedReferences
+        if pyro.sample('z', dist.Bernoulli(probs=torch.tensor([0.5]))):
+            x = pyro.sample('x', snis_opt_guidep)
+        else:
+            x = pyro.sample('x', snis_opt_guiden)
+
+        return OrderedDict(x=x)
+
+    NS_snis_opt = 10000
+
+    with SNISExpectation(
+            # The bi-modal optimal proposal covers the product of the model and the absolute value of the
+            # cost function.
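+            # [Editor's note: judging by the handler's name, SNIS presumably forms self-normalized
+            # importance weights w_i = p(x_i) / q(x_i) for x_i ~ q and returns
+            #
+            #     E_p[f] ~= sum_i w_i * f(x_i) / sum_j w_j,
+            #
+            # which is why the heavy-tailed student-t mixture above is used: a light-tailed q would let
+            # p / q blow up in the tails, exactly as discussed in the comment preceding snis_opt_guidep.]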
+ q=opt_guide_mix, + num_samples=NS_snis_opt + ): + print(f"SNIS OptGuide TABI Toy (N={NS_snis_opt})", + expectation(f=get_scost_ttp(dparams=D, c=C), p=model_ttp)["cost"]) + + tabi_opt_guidep = dist.Normal(*q_optimal_normal_guide_mean_var(**D, c=C, z=False)) + tabi_opt_guiden = dist.Normal(*q_optimal_normal_guide_mean_var(**D, c=C, z=True)) + + with TABIExpectation( + q_plus=lambda: OrderedDict(x=pyro.sample('x', tabi_opt_guidep)), + q_minus=lambda: OrderedDict(x=pyro.sample('x', tabi_opt_guiden)), + q_den=lambda: OrderedDict(x=pyro.sample('x', MODEL_TTP_DIST)), + num_samples=1 + ): + print(f"TABI Toy Exact (N=1)", expectation(f=get_scost_ttp(dparams=D, c=C), p=model_ttp)["cost"]) + + # + + # noinspection DuplicatedCode + def pos_comp(): + xp = model_ttp() + pos_fac = pyro.factor( + 'pos_fac', torch.log(1e-6 + torch.relu(dparam_scost_ttp(dparams=D, stochastics=xp, c=C)["cost"]))) + + return OrderedDict(x=xp, pos_fac=pos_fac) + + # noinspection DuplicatedCode + def neg_comp(): + xn = model_ttp() + neg_fac = pyro.factor( + 'neg_fac', torch.log(1e-6 + torch.relu(-dparam_scost_ttp(dparams=D, stochastics=xn, c=C)["cost"]))) + + return OrderedDict(x=xn, neg_fac=neg_fac) + + N_STEPS = 10000 + LR = 1e-3 + NS = 100 + + # noinspection DuplicatedCode + def run_svi_inference(model_, guide, n_steps=100, verbose=True): + elbo = pyro.infer.Trace_ELBO()(model_, guide) + elbo() + optim = torch.optim.SGD(elbo.parameters(), lr=LR) + for step in range(0, n_steps): + optim.zero_grad() + loss = elbo() + loss.backward() + optim.step() + if (step % 100 == 0) & verbose: + print("[iteration %04d] loss: %.4f" % (step, loss)) + return guide + + def get_num_guide_init(d, c, z: bool): + iloc, istd = q_optimal_normal_guide_mean_var(d=d, c=c, z=z) + return pyro.infer.autoguide.AutoNormal( + model=pos_comp, + init_loc_fn=pyro.infer.autoguide.initialization.init_to_value(values={'x': iloc}), + init_scale=istd.item() * 3., # scale up the std to give SVI something to do. + ), iloc, istd + pos_guide, pos_iloc, pos_istd = get_num_guide_init(**D, c=C, z=False) + neg_guide, neg_iloc, neg_istd = get_num_guide_init(**D, c=C, z=True) + + def get_den_guide_init(): + return pyro.infer.autoguide.AutoNormal( + model=model_ttp, + init_loc_fn=pyro.infer.autoguide.initialization.init_to_value( + # Very important to get a COPY of this tensor and not pass the model parameter itself. Otherwise + # when the guide updates the model will also change, which naturally leads to insanity. + values={'x': torch.tensor(MODEL_TTP_DIST.loc.item())}), + init_scale=MODEL_TTP_DIST.scale.item() * 3., # scale up the std to give SVI something to do. 
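+            # [Editor's note: the denominator guide is initialized at the model itself because TABI
+            # estimates the numerator and denominator expectations separately; when a guide q is exactly
+            # proportional to its unnormalized target (here p, or p * relu(+/- f) for the other two guides),
+            # the importance weight is constant in x, the estimator has zero variance, and a single sample
+            # suffices. That is what the "TABI Toy Exact (N=1)" print above demonstrates.]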
+ ) + den_guide = get_den_guide_init() + + def plot_tabi_guides(tabi_handler: TABIExpectation, og_pos, og_neg, og_den): + plt.figure() + xx_ = torch.linspace(-10, 10, 1000) + sns.kdeplot([tabi_handler.q_plus.forward()['x'].item() for _ in range(10000)], + label='pos', linestyle='--', color='red') + sns.kdeplot([tabi_handler.q_minus.forward()['x'].item() for _ in range(10000)], + label='neg', linestyle='--', color='blue') + sns.kdeplot([tabi_handler.q_den.forward()['x'].item() for _ in range(10000)], + label='den', linestyle='--', color='green') + + plt.plot(xx_, og_den.log_prob(xx_).exp(), color='green', alpha=0.5) + plt.plot(xx_, og_pos.log_prob(xx_).exp(), color='red', alpha=0.5) + plt.plot(xx_, og_neg.log_prob(xx_).exp(), color='blue', alpha=0.5) + + plt.show() + + # # <-----Manual execution> + # + # run_svi_inference(pos_comp, pos_guide, n_steps=N_STEPS, verbose=False) + # run_svi_inference(neg_comp, neg_guide, n_steps=N_STEPS, verbose=False) + # run_svi_inference(model_ttp, den_guide, n_steps=N_STEPS, verbose=False) + # with TABIExpectation( + # q_plus=pos_guide, + # q_minus=neg_guide, + # q_den=den_guide, + # num_samples=NS + # ) as te: + # print(f"TABI Toy Learned Guide (Manual) (N={NS})", + # expectation(f=get_scost_ttp(D, c=C), p=model_ttp)) + # + # plot_tabi_guides(te, dist.Normal(pos_iloc, pos_istd), dist.Normal(neg_iloc, neg_istd), MODEL_TTP_DIST) + # + # # + + # + + # + + with TABIExpectation( + q_plus=get_num_guide_init(**D, c=C, z=False)[0], + q_minus=get_num_guide_init(**D, c=C, z=True)[0], + q_den=get_den_guide_init(), + num_samples=NS + ) as te: + optimize_proposal(p=model_ttp, f=get_scost_ttp(dparams=D, c=C), n_steps=N_STEPS, lr=LR) + + tabi_ress = [] + for _ in range(100): + tabi_ress.append(expectation(f=get_scost_ttp(dparams=D, c=C), p=model_ttp)["cost"]) + + plt.figure() + plt.suptitle(f"TABI Toy Learned Guide (Integrated) (N={NS})") + plt.hist(tabi_ress, bins=20) + plt.axvline(x=GT, color='black', linestyle='--') + + plot_tabi_guides(te, dist.Normal(pos_iloc, pos_istd), dist.Normal(neg_iloc, neg_istd), MODEL_TTP_DIST) + + # + + # + + # Wrap in a tensor so we can optimize it. + dparams = OrderedDict(d=torch.nn.Parameter(torch.tensor(-1.))) + cval = tt(2.) # A value of 2 makes this a bit more difficult than the above. + + dprogression = [dparams['d'].item()] + dgrads = [] + + # The guides for the gradients have to be multi-modal to track with the multi-modal cost function. + # While the positive and negative components of the non-differential are themselves unimodal, when working + # directly with the gradients, each pos/neg component has a positive and negative component themselves. + # This means each positive/negative guide component has to be multimodal. + + # Plot the d/dd expectigrand as a function of x. + plt.figure() + xx = torch.linspace(-5., 5., 100) + ddf = build_expectigrand_gradient(dparams, get_dparam_scost_ttp(c=cval)) + plt.plot(xx, [ + (OppableDict(ddf(OrderedDict(x=x)))*MODEL_TTP_DIST.log_prob(x).exp())[_gradify_dparam_name("d")].detach().item() + for x in xx]) + plt.show() + + # We initialize guides with components at both the positive and negative components of the cost function. They + # will then adjust to capture the two components of each side of the gradient. 
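+    # [Editor's note, filling in the reasoning: the cost is a positive bump minus a negative bump in x,
+    # and differentiating a bump with respect to d turns it into a pair of positive and negative lobes
+    # around its center. So relu(df/dd) and relu(-df/dd) each inherit one lobe from *both* bumps, which
+    # is why each of the positive and negative gradient guides needs two components.]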
+ pos_iloc_grad, pos_istd_grad = q_optimal_normal_guide_mean_var(**D, c=cval, z=False) + neg_iloc_grad, neg_istd_grad = q_optimal_normal_guide_mean_var(**D, c=cval, z=True) + pos_guide_grad = MultiModalGuide1D( + num_components=2, + init_loc=[pos_iloc_grad.item(), neg_iloc_grad.item()], + init_scale=[pos_istd_grad.item(), neg_istd_grad.item()] + ) + neg_guide_grad = MultiModalGuide1D( + num_components=2, + init_loc=[pos_iloc_grad.item(), neg_iloc_grad.item()], + init_scale=[pos_istd_grad.item(), neg_istd_grad.item()] + ) + # The denominator doesn't involve the cost function, so it stays the same. + + def plotss_(title, te): + # Plot the d/dd expectigrand as a function of x. + plt.figure() + plt.suptitle(title) + + def plot_part_(sign, color): + xy = [] + with pyro.poutine.trace() as og_tr: + te.get_part(sign, f=ddf, p=model_ttp)() + for x in torch.linspace(-4., 4., 1000): + og_tr.get_trace().nodes['x']['value'] = x + with ReplayMessenger(og_tr.trace): + with pyro.poutine.trace() as tr: + te.get_part(sign, f=ddf, p=model_ttp)() + xy.append(( + tr.get_trace().nodes['x']['value'].item(), + tr.get_trace().log_prob_sum().exp().item() + )) + xy = np.array(xy) + plt.plot(*xy.T, color=color, linestyle='--') + + plot_part_(1., 'blue') + plot_part_(-1., 'red') + + # Plot an sns density plot of the guides. Use the same figure. + sns.kdeplot([te.q_plus()['x'].item() for _ in range(1000)], label="q_plus", bw_method=0.05, color='blue') + sns.kdeplot([te.q_minus()['x'].item() for _ in range(1000)], label="q_minus", bw_method=0.05, color='red') + sns.kdeplot([te.q_den()['x'].item() for _ in range(1000)], label="q_den", bw_method=0.25, color='green') + + # Also plot the true model density. + sns.kdeplot([model_ttp()['x'].item() for _ in range(1000)], + label="model", bw_method=0.25, color='green', linestyle='--') + + plt.legend() + plt.xlim(-5., 5.) + plt.show() + plt.close() + + # # <----Manual Execution> + # with TABIExpectation( + # q_plus=pos_guide_grad, + # q_minus=neg_guide_grad, + # q_den=get_den_guide_init(), + # num_samples=1 + # ) as te_: + # + # plotss_("Before", te_) + # + # grad_estimate = expectation(f=ddf, p=model_ttp) + # print(f"TABI Toy Learned Grads (Truth: -0.0575) (Before): {grad_estimate}") + # + # for _ in range(5): + # optimize_proposal(p=model_ttp, f=ddf, n_steps=3000, lr=1e-4, adjust_grads_=abort_guide_grads_) + # + # plotss_("After", te_) + # + # grad_estimate = expectation(f=ddf, p=model_ttp) + # print(f"TABI Toy Learned Grads (Truth: -0.0575) (After): {grad_estimate}") + # + # # Iteratively optimize the decision variable and the proposals. + # optim_ = torch.optim.SGD(tuple(dparams.values()), lr=1e-1) + # + # ii = 1 + # cii = 0 + # while cii < 30: + # + # optim_.zero_grad() + # for dp in dparams.values(): + # dp.grad = None + # + # # Move the decision variable. + # grad_estimate = expectation(f=ddf, p=model_ttp) + # print("Gradient Estimate", grad_estimate["d"].item()) + # + # grad_estimate = clip_decision_grads(grad_estimate) + # + # for k, dp in dparams.items(): + # + # dp.grad = grad_estimate[k] + # + # optim_.step() + # + # # Then update the proposals. + # optimize_proposal(p=model_ttp, f=ddf, n_steps=300, lr=3e-5, adjust_grads_=abort_guide_grads_) + # + # if ii % 30 == 0: + # plotss_(f"Decision {round(dparams['d'].item(), 3)}", te_) + # + # # And append the state of dten. 
+ # dprogression.append(dparams["d"].item()) + # dgrads.append(grad_estimate["d"].item()) + # + # ii += 1 + # + # if abs(dparams["d"].item()) < 1e-3: + # cii += 1 + # else: + # cii = 0 + # + # plt.figure() + # plt.plot(dprogression) + # plt.title("Decision Progression") + # + # plt.figure() + # plt.plot(dgrads) + # plt.title("Decision Gradients") + # + # plt.show() + # # + + # + + # + + te_ = TABIExpectation( + q_plus=pos_guide_grad, + q_minus=neg_guide_grad, + q_den=get_den_guide_init(), + num_samples=1) + dh_ = DecisionOptimizerHandler(dparams=dparams, lr=1e-1, proposal_update_lr=3e-5, proposal_update_steps=300, + proposal_adjust_grads_=abort_guide_grads_) + + with te_, dh_: + + plotss_("Before", te_) + + grad_estimate = expectation(f=ddf, p=model_ttp) + print(f"TABI Toy Learned Grads (Before): {grad_estimate}") + + # Perform initial optimization of proposals. + for _ in range(2): + # In the dh_ handler, the function here will be converted to the gradient with respect to d. + # TODO I don't know if I like this being implicit...? + optimize_proposal(p=model_ttp, f=get_dparam_scost_ttp(c=cval), n_steps=3000, lr=1e-4, + adjust_grads_=abort_guide_grads_) + + plotss_("After", te_) + + grad_estimate = expectation(f=ddf, p=model_ttp) + print(f"TABI Toy Learned Grads (After): {grad_estimate}") + + def terminal_condition_(dparams_, i): + + if i % 10 == 0: + plotss_(f"Decision {round(dparams_['d'].item(), 3)}", te_) + + ret = abs(dparams_['d'].item()) < 1e-2 + + if ret: + plotss_(f"Decision {round(dparams_['d'].item(), 3)}", te_) + + return ret + + optimize_decision(f=get_dparam_scost_ttp(c=cval), p=model_ttp, + terminal_condition=terminal_condition_, + adjust_grads=clip_decision_grads) + + # + + exit() + + +if __name__ == "__main__": + main() diff --git a/docs/source/expectation_programming/old_ep_demo_scratch_sri.py b/docs/source/expectation_programming/old_ep_demo_scratch_sri.py new file mode 100644 index 000000000..5b648a45b --- /dev/null +++ b/docs/source/expectation_programming/old_ep_demo_scratch_sri.py @@ -0,0 +1,621 @@ +from sri_tabi_problem import ReuseableSimulation, LOCKDOWN_TYPE, LockdownType +from torch import tensor as tt + +import pyro +import torch +from pyro.infer.autoguide import AutoMultivariateNormal +import pyro.distributions as dist +import matplotlib.pyplot as plt + +from enum import Enum + +from chirho.dynamical.ops import State, Trajectory + +from collections import OrderedDict + +from typing import ( + Optional, + Union, +) + +import numpy as np + +import old_ep_demo_scratch as stor + +from sklearn.neighbors import KernelDensity + +PT = torch.nn.Parameter +TT = torch.Tensor +UPTT = Union[PT, TT] + +if LOCKDOWN_TYPE == LockdownType.NONCONT_STATE: + DEFAULT_DPARAMS = DDP = OrderedDict( + # # Optimized. + # lockdown_trigger=tt(0.176), + # lockdown_lift_trigger=tt(0.457), + # lockdown_strength=tt(0.656) + + # lockdown_trigger=tt(0.23), + # lockdown_lift_trigger=tt(0.45), + # lockdown_strength=tt(0.69) + + lockdown_trigger=torch.nn.Parameter(tt(0.1)), + lockdown_lift_trigger=torch.nn.Parameter(tt(0.32)), + lockdown_strength=torch.nn.Parameter(tt(0.7)) + ) +elif LOCKDOWN_TYPE == LockdownType.CONT_PLATEAU: + DEFAULT_DPARAMS = DDP = OrderedDict( + # Optimal-ish for continouous plateau setup. + lockdown_trigger=tt(0.03), + lockdown_lift_trigger=tt(0.49), + lockdown_strength=tt(0.61) + + # # Sub-optimal but decent init for continuous plateau setup. 
+ # lockdown_trigger=tt(0.08), + # lockdown_lift_trigger=tt(0.6), + # lockdown_strength=tt(0.5) + ) +elif LOCKDOWN_TYPE == LockdownType.NONCONT_TIME: + DEFAULT_DPARAMS = DDP = OrderedDict( + lockdown_trigger=tt(0.8), + lockdown_lift_trigger=tt(8.0), + lockdown_strength=tt(0.5) + ) + +DEFAULT_INIT_STATE = DIS = State(S=tt(0.99), I=tt(0.01), R=tt(0.0), L=tt(0.0), l=tt(0.0), O=tt(0.0)) + +DEFAULT_STOCHASTICS = DST = OrderedDict( + beta=tt(2.), + gamma=tt(.4), + capacity=tt(0.01), + hospitalization_rate=tt(0.05) +) + +DEFAULT_TIMES = DT = torch.linspace(0., 20., 100) + +if LOCKDOWN_TYPE == LockdownType.NONCONT_STATE: + OEXPO = 1. + OSCALING = 2e2 +elif LOCKDOWN_TYPE == LockdownType.NONCONT_TIME: + OEXPO = 1. + OSCALING = 2e2 +elif LOCKDOWN_TYPE == LockdownType.CONT_PLATEAU: + OSCALING = 1.3e2 + OEXPO = 1.5 + + +def o_transform(o: torch.Tensor) -> torch.Tensor: + return ((1. + o) ** OEXPO - 1.) * OSCALING + + +def copy_odict(odict: OrderedDict[str, torch.Tensor]) -> OrderedDict[str, torch.Tensor]: + return OrderedDict((k, tt(v.item())) for k, v in odict.items()) + + +def f_traj( + dparams: stor.KWTypeNNParams, + stochastics: stor.KWType, + rs: Optional[ReuseableSimulation] = None) -> Trajectory[torch.Tensor]: + + # If dparams or stochastics are missing keys defined in the defaults, fill them in with copies of the default + # values. Also wrap lockdown triggers in their appropriate state objects. + dparams = copy_odict(DEFAULT_DPARAMS) | dparams + stochastics = copy_odict(DEFAULT_STOCHASTICS) | stochastics + + if rs is None: + rs = ReuseableSimulation() + + for k, v in stochastics.items(): + # All stochastics need to be positive, so just do it here and not worry about it in the guide. + stochastics[k] = torch.abs(v) + + for k, v in dparams.items(): + # All dparams need to be positive and non-zero, so put them through a relu with eps added. + dparams[k] = torch.relu(v) + 1e-3 + + traj = rs( + **dparams, + init_state=DEFAULT_INIT_STATE, + **stochastics, + times=DEFAULT_TIMES) + + return traj + + +def f_combined(dparams: stor.KWTypeNNParams, stochastics: stor.KWType) -> stor.KWType: + traj = f_traj(dparams, stochastics) + + total_hospital_overrun = traj[-1].O + total_lockdown_unpleasantness = traj[-1].L + + # For now, just combine these into a single cost. In future we want to minimize lockdown with constraints + # on total hospital overrun. + return OrderedDict(cost=o_transform(total_hospital_overrun) + total_lockdown_unpleasantness) + + +def f_o_only( + dparams: stor.KWTypeNNParams, + stochastics: stor.KWType, + # TODO lo1dop6k This avoids redundancy in the cost and constraint calls, but isn't used right now. + rs: Optional[ReuseableSimulation] = None) -> stor.KWType: + traj = f_traj(dparams, stochastics, rs=rs) + + total_hospital_overrun = traj[-1].O + + return OrderedDict(cost=o_transform(total_hospital_overrun)) + + +def f_l_only( + dparams: stor.KWTypeNNParams, + stochastics: stor.KWType, + # TODO lo1dop6k See tag above. + rs: Optional[ReuseableSimulation] = None) -> stor.KWType: + traj = f_traj(dparams, stochastics, rs=rs) + + total_lockdown_unpleasantness = traj[-1].L + + return OrderedDict(cost=total_lockdown_unpleasantness) + + +def plot_basic(dparams=None, stochastics=None): + + if dparams is None: + dparams = OrderedDict() + if stochastics is None: + stochastics = OrderedDict() + + traj = f_traj(dparams, stochastics) + + fig, (ax1, ax3, ax2) = plt.subplots(3, 1, figsize=(7, 10)) + tax3 = ax3.twinx() + ax2.axhline(DST['capacity'] * (1. 
/ DST['hospitalization_rate']), color='k', linestyle='--') + ax3.axhline(DST['capacity'] * (1. / DST['hospitalization_rate']), color='k', linestyle='--', + label='Healthcare Capacity') + ax2.plot(DT, traj.S, label='S', color='blue') + ax2.plot(DT, traj.I, label='I', color='red') + ax3.plot(DT, traj.I, label='I', color='red') + ax2.plot(DT, traj.R, label='R', color='green') + ax1.plot(DT, traj.L, label='Aggregate Lockdown', color='orange') + ax1.plot(DT, traj.l, label='Lockdown', color='orange', linestyle='--') + tax3.plot(DT, traj.O, label='Aggregate Overrun', color='k') + ax1.legend() + ax2.legend() + ax3.legend() + tax3.legend() + + ax2.set_xlabel('Time') + ax1.set_ylabel('Lockdown Strength') + ax2.set_ylabel('Proportion of Population') + ax3.set_ylabel('Proportion of Population') + tax3.set_ylabel('Aggregate Overrun') + + plt.tight_layout() + + plt.show() + + return + + +def plot_cost_vs_parameter(parameter_name, center: bool = False): + + if center: + c = DDP[parameter_name] + parameter_values = torch.linspace(torch.relu(c - 0.05), c + 0.05, 25) + else: + # Define the range of parameter values to consider + parameter_values = torch.linspace(0.01, 1.0, 100) + + # Initialize empty lists to store the corresponding cost values + cost_values = [] + o_only_values = [] + l_only_values = [] + + # Loop over the parameter values + for parameter_value in parameter_values: + # Calculate the cost for this parameter value + o_only = f_o_only(OrderedDict({parameter_name: parameter_value}), OrderedDict())['cost'] + l_only = f_l_only(OrderedDict({parameter_name: parameter_value}), OrderedDict())['cost'] + cost = f_combined(OrderedDict({parameter_name: parameter_value}), OrderedDict())['cost'] + + o_only_values.append(o_only.item()) + l_only_values.append(l_only.item()) + cost_values.append(cost.item()) + + # Create the plot + fig, ax = plt.subplots(1, 1) + ax.plot(parameter_values, cost_values) + ax.plot(parameter_values, o_only_values) + ax.plot(parameter_values, l_only_values) + ax.set_xlabel(parameter_name.capitalize()) + ax.set_ylabel('Cost') + ax.legend(['Combined', 'Overrun Only', 'Lockdown Only']) + ax.grid(True) + + plt.show() + + +def plot_cost_vs_parameters(parameter_name1, parameter_name2): + # Define the range of parameter values to consider + parameter_values1 = torch.linspace(0.01, 1.0, 10) # Adjust the number of points as needed + parameter_values2 = torch.linspace(0.01, 1.0, 10) # Adjust the number of points as needed + + # Create a meshgrid of parameter values + parameter_grid1, parameter_grid2 = torch.meshgrid(parameter_values1, parameter_values2) + + # Initialize empty tensors to store the corresponding cost values + cost_values = torch.zeros_like(parameter_grid1) + o_only_values = torch.zeros_like(parameter_grid1) + l_only_values = torch.zeros_like(parameter_grid1) + + # Loop over the parameter values + for i in range(parameter_values1.shape[0]): + for j in range(parameter_values2.shape[0]): + # Calculate the cost for this pair of parameter values + o_only = f_o_only(OrderedDict({parameter_name1: parameter_grid1[i, j], parameter_name2: parameter_grid2[i, j]}), OrderedDict())['cost'] + l_only = f_l_only(OrderedDict({parameter_name1: parameter_grid1[i, j], parameter_name2: parameter_grid2[i, j]}), OrderedDict())['cost'] + cost = f_combined(OrderedDict({parameter_name1: parameter_grid1[i, j], parameter_name2: parameter_grid2[i, j]}), OrderedDict())['cost'] + + o_only_values[i, j] = o_only.item() + l_only_values[i, j] = l_only.item() + cost_values[i, j] = cost.item() + + # Create the 
contour plots + fig, axs = plt.subplots(3, 1, figsize=(10, 12)) + + axs[0].contourf(parameter_grid1.numpy(), parameter_grid2.numpy(), o_only_values.numpy(), cmap='viridis') + axs[0].set_xlabel(parameter_name1.capitalize()) + axs[0].set_ylabel(parameter_name2.capitalize()) + axs[0].set_title('Overrun Cost') + fig.colorbar(ax=axs[0], mappable=axs[0].collections[0], label='Cost') + + axs[1].contourf(parameter_grid1.numpy(), parameter_grid2.numpy(), l_only_values.numpy(), cmap='viridis') + axs[1].set_xlabel(parameter_name1.capitalize()) + axs[1].set_ylabel(parameter_name2.capitalize()) + axs[1].set_title('Lockdown Cost') + fig.colorbar(ax=axs[1], mappable=axs[1].collections[0], label='Cost') + + axs[2].contourf(parameter_grid1.numpy(), parameter_grid2.numpy(), cost_values.numpy(), cmap='viridis') + axs[2].set_xlabel(parameter_name1.capitalize()) + axs[2].set_ylabel(parameter_name2.capitalize()) + axs[2].set_title('Combined Cost') + fig.colorbar(ax=axs[2], mappable=axs[2].collections[0], label='Cost') + + plt.tight_layout() + plt.show() + + +def _NNM_vectorized(f_, N, M, X, Y): + out = np.empty((N, N, M)) + it = np.nditer(out, flags=['multi_index']) + while not it.finished: + out[*it.multi_index[:-1], :] = f_(X[it.multi_index[:-1]], Y[it.multi_index[:-1]]) + it.iternext() + + return out + + +def plot_cost_likelihood_convolution_for_stochastics( + stochastic_name1: str, stochastic_name2: str, + p: stor.ModelType, f_: stor.ExpectigrandType, n=1000): + + samples = [] + for _ in range(n): # Generate 1000 samples + sample = p() + samples.append((sample[stochastic_name1].item(), sample[stochastic_name2].item())) + + samples = np.array(samples) + + kde: KernelDensity = KernelDensity(kernel='gaussian', bandwidth=0.02).fit(samples) + + def density(s1, s2) -> np.ndarray: + return np.array([np.exp(kde.score_samples([[s1, s2]]))]) + + def cost(s1, s2) -> np.ndarray: + vals = f_(OrderedDict({stochastic_name1: tt(s1), stochastic_name2: tt(s2)})).values() + return np.array(tuple(v.item() for v in vals)) + + def cust_colorbar(ax, arr): + ticks = np.linspace(0., 1.0, 10) + cbar = fig.colorbar(ax=ax, mappable=ax.collections[0], ticks=ticks) + arrmax = arr.max() + arrmin = arr.min() + tick_labels = ticks * (arrmax - arrmin) + arrmin + cbar.ax.set_yticklabels(['{:.1f}'.format(tick) for tick in tick_labels]) + + resolution = 15 + + s1ls = np.linspace(0.00, 1.0, resolution) + s2ls = np.linspace(0.00, 1.0, resolution) + + s1ls = s1ls * (samples[:, 0].max() - samples[:, 0].min()) + samples[:, 0].min() + s2ls = s2ls * (samples[:, 1].max() - samples[:, 1].min()) + samples[:, 1].min() + + X, Y = np.meshgrid(s1ls, s2ls) + + # Make a subplot for the density and one for each component of the cost. + cost_component_names = f_(OrderedDict({stochastic_name1: tt(s1ls[0]), stochastic_name2: tt(s2ls[0])})).keys() + num_cost_components = len(cost_component_names) + + fig, axs = plt.subplots(num_cost_components + 1, 3, figsize=(6*num_cost_components, 18)) + + # Plot the density in the top row. + density_array = _NNM_vectorized(f_=density, N=resolution, M=1, X=X, Y=Y) + for col in range(3): + axs[0][col].contourf(X, Y, density_array[..., 0], cmap='viridis') + cust_colorbar(axs[0][col], density_array[..., 0]) + + # Compute the cost array. We have to do this manually because there + cost_array = _NNM_vectorized(f_=cost, N=resolution, M=num_cost_components, X=X, Y=Y) + + assert cost_array.shape == (resolution, resolution, num_cost_components) + + # Plot the cost components in the left column. 
+ for i, cost_component_name in enumerate(cost_component_names): + axs[i+1][0].contourf(X, Y, cost_array[:, :, i], cmap='viridis') + axs[i+1][0].set_title(cost_component_name.capitalize()) + cust_colorbar(axs[i+1][0], cost_array[:, :, i]) + + # This maybe doesn't give the right shape. + cost_density_convolution_array = cost_array * density_array + assert cost_density_convolution_array.shape == (resolution, resolution, num_cost_components) + + # Plot the positive part of the convolved cost components in the middle column. + + def plot_part(arr, col): + for i, cost_component_name in enumerate(cost_component_names): + + axs[i+1][col].contourf(X, Y, arr[:, :, i], cmap='viridis') + axs[i+1][col].set_title(cost_component_name + (' +' if col == 1 else ' -')) + cust_colorbar(axs[i+1][col], arr[:, :, i]) + + # And draw crosshairs at the original density mean, for comparison. + axs[i+1][col].axvline(x=samples[:, 0].mean(), color='white', linestyle='--') + axs[i+1][col].axhline(y=samples[:, 1].mean(), color='white', linestyle='--') + + plot_part(np.maximum(cost_density_convolution_array, 0.0), 1) + plot_part(-np.minimum(cost_density_convolution_array, 0.0), 2) + + plt.tight_layout() + + plt.show() + + +def pyro_prior_over_sirlo_params(): + beta = pyro.sample( + "beta", dist.Normal(2., 0.3)) + gamma = pyro.sample( + "gamma", dist.Normal(.4, .06)) + capacity = pyro.sample( + "capacity", dist.Normal(0.01, 0.003)) + hospitalization_rate = pyro.sample( + "hospitalization_rate", dist.Normal(0.05, 0.015)) + return OrderedDict( + beta=beta, + gamma=gamma, + capacity=capacity, + hospitalization_rate=hospitalization_rate + ) + + +def pyro_prior_over_sirlo_params_2d(): + beta = pyro.sample( + "beta", dist.Normal(2., 0.1)) + gamma = pyro.sample( + "gamma", dist.Normal(.4, .02)) + return OrderedDict( + beta=beta, + gamma=gamma + ) + + +def _grad_debugging(): + dparams = OrderedDict( + lockdown_trigger=torch.nn.Parameter(tt(DDP['lockdown_trigger'].item())), + lockdown_lift_trigger=torch.nn.Parameter(tt(DDP['lockdown_lift_trigger'].item())), + lockdown_strength=torch.nn.Parameter(tt(DDP['lockdown_strength'].item())) + ) + + overrun = f_o_only(dparams, OrderedDict()) + lockdown = f_l_only(dparams, OrderedDict()) + combined = f_combined(dparams, OrderedDict()) + + traj = f_traj(dparams, OrderedDict()) + + # Make sure the gradient of the lockdown trigger end state wrt the param is one. + dltdlt = torch.autograd.grad( + outputs=(traj.lockdown_trigger[-1],), + inputs=tuple(dparams.values()), + create_graph=True) + assert torch.isclose(dltdlt[0], tt(1.0)) + assert torch.isclose(dltdlt[1], tt(0.0)) + assert torch.isclose(dltdlt[2], tt(0.0)) + + # Make sure the gradient of the lockdown lift trigger end state wrt the param is zero. + dlftdlft = torch.autograd.grad( + outputs=(traj.lockdown_lift_trigger[-1],), + inputs=tuple(dparams.values()), + create_graph=True) + assert torch.isclose(dlftdlft[0], tt(0.0), atol=1e-4) + assert torch.isclose(dlftdlft[1], tt(1.0)) + assert torch.isclose(dlftdlft[2], tt(0.0)) + + if LOCKDOWN_TYPE != LockdownType.CONT_PLATEAU: + # Make sure the gradient of the lockdown strength wrt itself is one. + dlstdlst = torch.autograd.grad( + outputs=(traj.l[30],), # get strength when lockdown is active. 
+ inputs=tuple(dparams.values()), + create_graph=True) + assert torch.isclose(dlstdlst[0], tt(0.0)) + assert torch.isclose(dlstdlst[1], tt(0.0)) + assert torch.isclose(dlstdlst[2], tt(1.0)) + + dCd = torch.autograd.grad( + outputs=(combined["cost"],), + inputs=tuple(dparams.values()), + # inputs=(dparams['lockdown_strength'],), + # inputs=(dparams['lockdown_trigger'],), + # inputs=(dparams['lockdown_lift_trigger'],), + create_graph=True) + + assert not torch.isclose(dCd[0], tt(0.0), atol=1e-4) + assert not torch.isclose(dCd[1], tt(0.0), atol=1e-4) + assert not torch.isclose(dCd[2], tt(0.0), atol=1e-4) + + +class ConstraintType(Enum): + JOINT = 1 + MEAN = 2 + + +def optimize_decision_2d_latent(constraint_type: ConstraintType): + + # Beta and gamma are the only latents. + p_ = pyro_prior_over_sirlo_params_2d + + dparams = OrderedDict( + lockdown_trigger=torch.nn.Parameter(tt(0.1)), + lockdown_lift_trigger=torch.nn.Parameter(tt(0.3)), + lockdown_strength=torch.nn.Parameter(tt(0.7)) + ) + + q_plus_guide = AutoMultivariateNormal( + model=p_, + ) + q_minus_guide = AutoMultivariateNormal( + model=p_ + ) + q_den_guide = AutoMultivariateNormal( + model=p_ + ) + + te = stor.TABIExpectation( + q_plus=q_plus_guide, + q_minus=q_minus_guide, + q_den=q_den_guide, + num_samples=1 + ) + + def abort_guide_grads_(*parameters: torch.nn.Parameter): + # These gradients also blow up, but clipping them causes weird non-convergence. Just aborting + # the gradient update seems to work. + if torch.any(torch.tensor([torch.any(torch.abs(param.grad) > 400.) for param in parameters])): + for param in parameters: + param.grad = torch.zeros_like(param.grad) + + dh = stor.DecisionOptimizerHandler( + dparams=dparams, + lr=1e-1, + proposal_update_lr=1e-4, + # TODO restore this once the constraints can get dedicated proposals. + proposal_update_steps=10 if constraint_type == ConstraintType.JOINT else 0, + proposal_adjust_grads_=abort_guide_grads_ + ) + + mc = stor.MeanConstraintHandler( + g=f_o_only, + tau=1., + threshold=0.001 + ) if constraint_type == ConstraintType.MEAN else pyro.poutine.messenger.Messenger() + + def terminal_condition(_, i: int) -> bool: + return i > 1000 # TODO more sophisticated convergence criterion. 
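+
+    # [Editor's note, an illustrative sketch rather than part of the original change: with
+    # constraint_type == ConstraintType.MEAN, the MeanConstraintHandler above converts the constrained
+    # problem "minimize E[f_l_only] subject to E[f_o_only] <= threshold" into an unconstrained one by a
+    # penalty method, adding the scaled constraint gradient whenever the constraint is violated, roughly:
+    #
+    #     grad_total = grad E[f_l_only] + tau * grad E[f_o_only]    if constraint violated
+    #     grad_total = grad E[f_l_only]                             otherwise
+    #
+    # Note that rewrapped_df_dd checks the violation per stochastic sample, so this is an approximation
+    # of the mean constraint.]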
+
+    betagammas_q_plus_progression = []
+    betagammas_q_minus_progression = []
+    betagammas_q_den_progression = []
+    lockdown_strength_progression = []
+    lockdown_trigger_progression = []
+    lockdown_lift_trigger_progression = []
+
+    def save_progressions():
+        np.savez(
+            "/Users/azane/Desktop/sirlo_logs/sirlo_opt.npz",
+            lockdown_strength_progression=np.array(lockdown_strength_progression),
+            lockdown_trigger_progression=np.array(lockdown_trigger_progression),
+            lockdown_lift_trigger_progression=np.array(lockdown_lift_trigger_progression),
+            beta_q_plus_progression=np.array([bg['beta'].detach().item() for bg in betagammas_q_plus_progression]),
+            gamma_q_plus_progression=np.array([bg['gamma'].detach().item() for bg in betagammas_q_plus_progression]),
+            beta_q_minus_progression=np.array([bg['beta'].detach().item() for bg in betagammas_q_minus_progression]),
+            gamma_q_minus_progression=np.array([bg['gamma'].detach().item() for bg in betagammas_q_minus_progression]),
+            beta_q_den_progression=np.array([bg['beta'].detach().item() for bg in betagammas_q_den_progression]),
+            gamma_q_den_progression=np.array([bg['gamma'].detach().item() for bg in betagammas_q_den_progression]),
+        )
+
+    class OptimizeProposalCallback:
+
+        init_optimize_proposal_iterations = 0
+
+        def __call__(self):
+            self.init_optimize_proposal_iterations += 1
+            print(f"Optimizing proposal {self.init_optimize_proposal_iterations}.")
+
+            betagammas_q_plus_progression.append(q_plus_guide.forward())
+            betagammas_q_minus_progression.append(q_minus_guide.forward())
+            betagammas_q_den_progression.append(q_den_guide.forward())
+
+            save_progressions()
+
+    class OptimizeDecisionCallback:
+        optimize_decision_iterations = 0
+
+        def __call__(self):
+            self.optimize_decision_iterations += 1
+            print(f"Optimizing decision {self.optimize_decision_iterations}.")
+
+            lockdown_strength_progression.append(dh.dparams['lockdown_strength'].detach().item())
+            lockdown_trigger_progression.append(dh.dparams['lockdown_trigger'].detach().item())
+            lockdown_lift_trigger_progression.append(dh.dparams['lockdown_lift_trigger'].detach().item())
+
+            betagammas_q_plus_progression.append(q_plus_guide.forward())
+            betagammas_q_minus_progression.append(q_minus_guide.forward())
+            betagammas_q_den_progression.append(q_den_guide.forward())
+
+            save_progressions()
+
+    with te, dh, mc:
+
+        # # This changes what optimize_decision sees as the stochastic program to be the product of a conditioned
+        # # inference procedure.
+        # with pyro.condition(sir_noisy_data):
+
+        # TODO enable once explicit constraints can get dedicated proposals.
+        if constraint_type == ConstraintType.JOINT:
+            # Initial optimization of proposals.
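+            # [Editor's note: because dh (a DecisionOptimizerHandler) is active here, this call is
+            # intercepted by _pyro_optimize_proposal, so the guides are presumably fit to the gradient
+            # integrand d/d(dparams) of f_combined rather than to f_combined itself.]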
+            stor.optimize_proposal(
+                p=p_,
+                f=f_combined,
+                n_steps=10,
+                lr=1e-4,
+                adjust_grads_=dh.proposal_adjust_grads_,
+                callback=OptimizeProposalCallback()
+            )
+
+        optimal_decision = stor.optimize_decision(
+            p=p_,
+            f=f_combined if constraint_type == ConstraintType.JOINT else f_l_only,
+            terminal_condition=terminal_condition,
+            adjust_grads=lambda g: OrderedDict([(k, torch.clip(g[k], -1./100., 1./100.)) for k in g.keys()]),
+            callback=OptimizeDecisionCallback(),
+        )
+
+        return  # just here to put a breakpoint on
+
+
+if __name__ == '__main__':
+
+    # plot_basic()
+
+    # plot_cost_vs_parameter('lockdown_strength')
+    # plot_cost_vs_parameter('lockdown_trigger')
+    # plot_cost_vs_parameter('lockdown_lift_trigger')
+    # plot_cost_vs_parameters('lockdown_strength', 'lockdown_trigger')
+    # plot_cost_vs_parameters('lockdown_strength', 'lockdown_lift_trigger')
+    # plot_cost_vs_parameters('lockdown_trigger', 'lockdown_lift_trigger')
+
+    optimize_decision_2d_latent(constraint_type=ConstraintType.MEAN)
+    # plot_basic(OrderedDict(lockdown_strength=tt(.663)))
+
+    # _grad_debugging()
+
+    # plot_cost_likelihood_convolution_for_stochastics(
+    #     'beta', 'gamma', pyro_prior_over_sirlo_params_2d, f_=stor.build_expectigrand_gradient(DDP, f_combined))
diff --git a/docs/source/expectation_programming/sri_tabi_problem.py b/docs/source/expectation_programming/sri_tabi_problem.py
new file mode 100644
index 000000000..9be5be17b
--- /dev/null
+++ b/docs/source/expectation_programming/sri_tabi_problem.py
@@ -0,0 +1,197 @@
+import pyro
+import torch
+import pyro.distributions as dist
+
+from chirho.dynamical.handlers import (
+    DynamicIntervention,
+    SimulatorEventLoop,
+    simulate,
+    ODEDynamics,
+)
+
+from chirho.dynamical.ops import State
+from torch import tensor as tt
+from enum import Enum
+
+
+# Largely for debugging...
+class LockdownType(Enum):
+    CONT_PLATEAU = 1
+    NONCONT_STATE = 2
+    NONCONT_TIME = 3
+
+
+LOCKDOWN_TYPE = LockdownType.NONCONT_STATE
+
+
+class SimpleSIRDynamics(ODEDynamics):
+    def __init__(self, beta, gamma):
+        super().__init__()
+        self.beta = beta
+        self.gamma = gamma
+
+    def diff(self, dX: State[torch.Tensor], X: State[torch.Tensor]):
+        dX.S = -self.beta * X.S * X.I
+        dX.I = self.beta * X.S * X.I - self.gamma * X.I
+        dX.R = self.gamma * X.I
+
+    def observation(self, X: State[torch.Tensor]):
+
+        I_obs = pyro.sample("I_obs", dist.Poisson(X.I))  # noisy number of infected actually observed
+        R_obs = pyro.sample("R_obs", dist.Poisson(X.R))  # noisy number of recovered actually observed
+
+        return {
+            "I_obs": I_obs,
+            "R_obs": R_obs,
+        }
+
+
+class SimpleSIRDynamicsLockdown(SimpleSIRDynamics):
+    def __init__(self, beta0, gamma):
+        super().__init__(torch.zeros_like(gamma), gamma)
+        self.beta0 = beta0
+
+    def diff(self, dX: State[torch.Tensor], X: State[torch.Tensor]):
+
+        # Lockdown strength is a piecewise constant function affected by dynamic interventions, so dX is 0.
+        dX.l = torch.tensor(0.0)
+
+        if LOCKDOWN_TYPE == LockdownType.CONT_PLATEAU:
+            # Pretend the trigger is also in terms of the number recovered.
+            recovered_trigger = X.lockdown_trigger
+            # Pretend the lift trigger is in terms of how many additional recovered need to trigger.
+            recovered_lift_trigger = recovered_trigger + X.lockdown_lift_trigger
+            # Define dX.l as a continuous plateau.
+            plateau_u = (recovered_lift_trigger + recovered_trigger) / tt(2.)
+            plateau_s = (recovered_lift_trigger - recovered_trigger) / tt(2.)
+            dXL_override = X.static_lockdown_strength * torch.exp(-((X.R - plateau_u) / plateau_s)**tt(10))
+
+            dX.L = dXL_override
+            self.beta = (1 - dXL_override) * self.beta0
+        else:
+            # Time-varying beta parametrized by the lockdown strength l_t.
+            self.beta = (1 - X.l) * self.beta0
+
+            # Accrual of lockdown time.
+            dX.L = X.l
+
+        # Constant event parameters have to be in the state in order for torchdiffeq to give derivs.
+        dX.lockdown_trigger = tt(0.0)
+        dX.lockdown_lift_trigger = tt(0.0)
+        dX.static_lockdown_strength = tt(0.0)
+
+        # Call the base SIR class diff method.
+        super().diff(dX, X)
+
+
+class SimpleSIRDynamicsLockdownCapacityOverrun(SimpleSIRDynamicsLockdown):
+    def __init__(self, beta0, gamma, capacity, hospitalization_rate):
+        super().__init__(beta0, gamma)
+
+        self.capacity = capacity
+        self.hospitalization_rate = hospitalization_rate
+
+    def diff(self, dX: State[torch.Tensor], X: State[torch.Tensor]):
+        # If the number of infected individuals needing hospitalization exceeds the capacity, accrue that difference
+        # in the overrun factor.
+        dX.O = torch.relu(X.I * self.hospitalization_rate - self.capacity)
+
+        super().diff(dX, X)
+
+
+def initiate_lockdown(t: torch.Tensor, state: State[torch.Tensor]):
+    target_recovered = state.lockdown_lift_trigger
+    target_infected = state.lockdown_trigger
+
+    if LOCKDOWN_TYPE == LockdownType.NONCONT_STATE:
+        # To enact the policy, require that the lift trigger hasn't effectively already fired.
+        # Do this by treating the infected count as zero if the lift trigger has fired.
+        # As the lockdown approaches its firing state (the infected count is increasing), this value
+        # goes to zero; but if the lift trigger has fired, it jumps back up to the non-zero target
+        # state and the lockdown never fires.
+        return target_infected - state.I * torch.greater_equal(target_recovered, state.R).type(state.R.dtype)
+    elif LOCKDOWN_TYPE == LockdownType.NONCONT_TIME:
+        return target_infected - t
+    elif LOCKDOWN_TYPE == LockdownType.CONT_PLATEAU:
+        # Disabled entirely for continuous plateau.
+        return tt(1.)
+
+
+def lift_lockdown(t: torch.Tensor, state: State[torch.Tensor]):
+    target_recovered = state.lockdown_lift_trigger
+
+    if LOCKDOWN_TYPE == LockdownType.NONCONT_STATE:
+        # To lift the policy, require that the recovered count exceeds a certain level, and that a lockdown of
+        # non-zero strength is in place.
+        return target_recovered - state.R + torch.isclose(state.l, torch.tensor(0.0)).type(state.l.dtype)
+    elif LOCKDOWN_TYPE == LockdownType.NONCONT_TIME:
+        return target_recovered - t
+    elif LOCKDOWN_TYPE == LockdownType.CONT_PLATEAU:
+        # Disabled entirely for continuous plateau.
+        return tt(1.)
+
+
+class ReuseableSimulation:
+    # TODO Can something like this be accomplished with functools.partial?
+    def __init__(self):
+        self.result = None
+
+    @staticmethod
+    def _inner_call(lockdown_trigger, lockdown_lift_trigger, lockdown_strength,
+                    init_state, beta, gamma, capacity, hospitalization_rate, times, **kwargs):
+
+        # Make a new state object so we can add the lockdown trigger constants to the state without modifying the
+        # original. This is needed because torchdiffeq requires that even constant event parameters be in the
+        # state in order to take gradients with respect to them.
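+        # [Editor's note, a sketch of the pattern assumed here: torchdiffeq's event handling can only
+        # differentiate with respect to tensors that are packed into the integrated state, so a constant
+        # one wants gradients for is carried along with a zero time-derivative, i.e.
+        #
+        #     dX.some_constant = tt(0.0)                     # in diff(): constant over time
+        #     setattr(init_state, "some_constant", value)    # but present in the state for autograd
+        #
+        # which is exactly what dX.lockdown_trigger et al. do in SimpleSIRDynamicsLockdown.diff above.]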
+ new_init_state = State() + for k in init_state.keys: + setattr(new_init_state, k, getattr(init_state, k)) + setattr(new_init_state, "lockdown_trigger", lockdown_trigger) + setattr(new_init_state, "lockdown_lift_trigger", lockdown_lift_trigger) + setattr(new_init_state, "static_lockdown_strength", lockdown_strength) + + if torch.isclose(lockdown_trigger, torch.tensor(0.0)) or torch.less(lockdown_trigger, torch.tensor(0.0)): + raise ValueError("Lockdown trigger must be greater than zero.") + + if torch.isclose(lockdown_lift_trigger, torch.tensor(0.0)) or torch.less(lockdown_lift_trigger, torch.tensor(0.0)): + raise ValueError("Lockdown lift trigger must be greater than zero.") + + sir = SimpleSIRDynamicsLockdownCapacityOverrun(beta, gamma, capacity, hospitalization_rate) + with SimulatorEventLoop(): + with DynamicIntervention(event_f=initiate_lockdown, + intervention=State(l=lockdown_strength), + var_order=new_init_state.var_order, max_applications=1, ): + with DynamicIntervention(event_f=lift_lockdown, + intervention=State(l=torch.tensor(0.0)), var_order=new_init_state.var_order, + max_applications=1): + tspan = times + if not torch.isclose(tspan[0], torch.tensor(0.0)): + tspan = torch.cat([torch.tensor([0.]), tspan]) + + soln = simulate(sir, new_init_state, tspan) + + return soln + + def __call__(self, *args, **kwargs): + if self.result is None: + self.result = self._inner_call(*args, **kwargs) + else: + # TODO assert that the arguments are the same? + pass + + return self.result + + +def cost(dparams, stochastics, end_time, reuseable_sim) -> torch.Tensor: + sim_results = reuseable_sim(dparams, stochastics, end_time) + + # The cost of the lockdown policy is the total lockdown time times the severity. See lockdown diff. + return sim_results[-1].L + + +def failure_magnitude(dparams, stochastics, end_time, reuseable_sim) -> torch.Tensor: + sim_results = reuseable_sim(dparams, stochastics, end_time) + + # The failure magnitude of lockdown policy is the total time spent by infected individuals who needed + # hospitalization but could not get it. + return sim_results[-1].O diff --git a/docs/source/expectation_programming/toy_tabi_problem.py b/docs/source/expectation_programming/toy_tabi_problem.py new file mode 100644 index 000000000..2f989efae --- /dev/null +++ b/docs/source/expectation_programming/toy_tabi_problem.py @@ -0,0 +1,129 @@ +import torch +import pyro +import pyro.distributions as dist +from collections import OrderedDict +from torch import tensor as tt + + +def h(d, c): + # Convert the decision parameter to a convex space with defined maximum (1+c) and minimum (c). + # This is the root form of a student-t distribution, giving us longer tails and a better ramp up for + # the decision parameter gradients. + return (1. + d**2.)**-1. + c + + +def cost_part(d, x, c, z: bool): + # This builds out a part of the cost function. It's a normal positioned in the tails to the degree defined by + # c. If z is true, this is the negative part of the cost function in the left tail, otherwise it's the positive + # part of the cost function in the right tail. Each part is normal-distribution shaped with a mean of h(d, c) + # and h(d, c) - 3 * c, respectively. The variance is fixed at .2**2. + # numerator = -(x - h(d, c) + 3 * float(z) * c)**2. + # denominator = torch.tensor(.2**2.) 
+ # return .5 * torch.exp(numerator / denominator) + + mean = h(d, c) - 3 * float(z) * c + std = .2 + return torch.exp(dist.Normal(mean, std).log_prob(x)) + + +def cost(d: torch.Tensor, x: torch.Tensor, c: torch.Tensor, **kwargs) -> torch.Tensor: + # The additional kwargs just takes stochastics and decision parameters that don't actually play into this + # cost function. This is required because all decision parameters and all stochastics are unpacked here + # as keyword arguments to the cost function, and there are auxiliary stochastics we don't care about. + return cost_part(d, x, c, False) - cost_part(d, x, c, True) + + +def q_optimal_normal_guide_mean_var(d, c, z: bool): + # This is the product of the standard normal and the cost normal pdfs. + # See e.g. https://www.johndcook.com/blog/2012/10/29/product-of-normal-pdfs/ for more details. + m = h(d, c) - 3 * float(z) * c + vp = tt(.2**2.) + vn = tt(.2**-2) + + new_mean = (vn * m) / (vn + 1) + new_var = vp / (1 + vp) + + return new_mean, torch.sqrt(new_var) + + +MODEL_DIST = dist.Normal(0., 1.) + + +def model(): + # The model is simply a unit normal. + return OrderedDict(x=pyro.sample("x", MODEL_DIST)) + + +if __name__ == "__main__": + + import matplotlib.pyplot as plt + + xx = torch.linspace(-5., 5., 300) + + c_ = tt(1.5) + d1_ = tt(2.) + d2_ = tt(0.) + d3_ = tt(0.5) + + # Plot the decision parameter re-mapping. + plt.figure() + plt.suptitle('Decision Parameter Re-Mapping') + plt.plot(xx, h(xx, c_)) + + # Plot the cost function for the two decision parameters. + fig, (ax1, ax2, ax3) = plt.subplots(3, 1) + plt.suptitle('Cost by Model on Stochastics') + + # Plot the cost function for the two decision parameters. Put in the different subplots. + ax1.plot(xx, cost(d1_, xx, c_), color='red', label=f"d={d1_}") + ax2.plot(xx, cost(d2_, xx, c_), color='red', label=f"d={d2_}") + ax3.plot(xx, cost(d3_, xx, c_), color='red', label=f"d={d3_}") + + # Plot the model on top of the cost function. + ax1.plot(xx, torch.exp(MODEL_DIST.log_prob(xx)), color='blue') + ax2.plot(xx, torch.exp(MODEL_DIST.log_prob(xx)), color='blue') + ax3.plot(xx, torch.exp(MODEL_DIST.log_prob(xx)), color='blue') + + # Plot the optimal guides for the two parts of the cost function. These are normal distributions with means and + # variances calculated from the product of the model pdf and the cost function "pdf". + for d_, ax in zip([d1_, d2_, d3_], [ax1, ax2, ax3]): + gpm, gms = q_optimal_normal_guide_mean_var(d_, c_, False) + gnm, gns = q_optimal_normal_guide_mean_var(d_, c_, True) + ax.plot(xx, torch.exp(-((xx - gpm)/gms)**tt(2.))*.4, color='orange', label='Optimal Guide') + ax.plot(xx, torch.exp(-((xx - gnm)/gns)**tt(2.))*.4, color='orange') + + ax1.legend() + ax2.legend() + ax3.legend() + + # In a different plot, show the unnormalized positive component (using torch.relu) of the cost function + # multiplied by the model pdf. Also show the negative component (this can be achieved by using torch.relu(-x)). + # Just do this for d3. + fig1, ax = plt.subplots() + tax = ax.twinx() + tax.plot(xx, torch.exp(MODEL_DIST.log_prob(xx)), color='blue', label=f"d={d3_}", linestyle='--') + ax.plot(xx, torch.exp(MODEL_DIST.log_prob(xx)) * torch.relu(1e2*cost(d3_, xx, c_)), color='green', label=f"d={d3_}") + ax.plot(xx, torch.exp(MODEL_DIST.log_prob(xx)) * torch.relu(-1e2*cost(d3_, xx, c_)), color='red', label=f"d={d3_}") + + # Plot just the cost function with a thin line. 
+ ax.plot(xx, cost(d3_, xx, c_), color='black', linestyle='--', linewidth=0.4) + + # Now plot the properly scaled normal distributions that map to the normalizations of the curves plotted above. + gpm, gms = q_optimal_normal_guide_mean_var(d3_, c_, False) + gnm, gns = q_optimal_normal_guide_mean_var(d3_, c_, True) + + ax.plot(xx, torch.exp(dist.Normal(gpm, gms).log_prob(xx)), color='orange', linestyle='--') + ax.plot(xx, torch.exp(dist.Normal(gnm, gns).log_prob(xx)), color='orange', linestyle='--') + + # In a different figure, show the ratio of the model-scaled positive and negative components of the cost function + # with respect to the properly normalized optimal guides. Show this across the same xx. + fig2, ax = plt.subplots() + num_pos = torch.exp(MODEL_DIST.log_prob(xx)) * torch.relu(1e2 * cost(d3_, xx, c_)) + den_pos = torch.exp(dist.Normal(gpm, gms).log_prob(xx)) + plt.plot(xx, num_pos / den_pos, color='green', label='Positive Component') + + num_neg = torch.exp(MODEL_DIST.log_prob(xx)) * torch.relu(-1e2 * cost(d3_, xx, c_)) + den_neg = torch.exp(dist.Normal(gnm, gns).log_prob(xx)) + plt.plot(xx, num_neg / den_neg, color='red', label='Negative Component') + + plt.show()
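+
+    # [Editor's addition, a hypothetical sanity check that is not part of the original file: verify
+    # numerically that model pdf * cost_part is proportional to the optimal-guide pdf returned by
+    # q_optimal_normal_guide_mean_var, i.e. that the log-ratio is constant in x. The names below
+    # (gpm_, m_, log_ratio_) are new and introduced only for this check.]
+    gpm_, gms_ = q_optimal_normal_guide_mean_var(d3_, c_, False)
+    m_ = h(d3_, c_)
+    # log[ N(x; 0, 1) * N(x; m_, .2^2) / N(x; gpm_, gms_^2) ] should be constant in x.
+    log_ratio_ = MODEL_DIST.log_prob(xx) + dist.Normal(m_, .2).log_prob(xx) \
+        - dist.Normal(gpm_, gms_).log_prob(xx)
+    print("optimal-guide log-ratio spread (should be ~0):", (log_ratio_.max() - log_ratio_.min()).item())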