Add modified Bayesian regression tutorial with more direct PyTorch usage #2996

Open · wants to merge 1 commit into base: dev
4 changes: 4 additions & 0 deletions pyro/__init__.py
@@ -5,6 +5,7 @@
from pyro.infer.inspect import render_model
from pyro.logger import log
from pyro.poutine import condition, do, markov
from pyro.poutine.runtime import enable_module_local_param
from pyro.primitives import (
barrier,
clear_param_store,
@@ -21,6 +22,7 @@
random_module,
sample,
subsample,
use_param_store,
validation_enabled,
)
from pyro.util import set_rng_seed
@@ -42,6 +44,7 @@
"deterministic",
"do",
"enable_validation",
"enable_module_local_param",
"factor",
"get_param_store",
"iarange",
@@ -59,5 +62,6 @@
"sample",
"set_rng_seed",
"subsample",
"use_param_store",
"validation_enabled",
]
16 changes: 16 additions & 0 deletions pyro/infer/elbo.py
@@ -5,13 +5,26 @@
import warnings
from abc import ABCMeta, abstractmethod

import torch

import pyro
import pyro.poutine as poutine
from pyro.infer.util import is_validation_enabled
from pyro.poutine.util import prune_subsample_sites
from pyro.util import check_site_shape


class _ELBOModule(torch.nn.Module):
def __init__(self, model, guide, elbo):
super().__init__()
self.model = model
self.guide = guide
self.elbo = elbo

def forward(self, *args, **kwargs):
return self.elbo.differentiable_loss(self.model, self.guide, *args, **kwargs)


class ELBO(object, metaclass=ABCMeta):
"""
:class:`ELBO` is the top-level interface for stochastic variational
@@ -86,6 +99,9 @@ def __init__(
self.jit_options = jit_options
self.tail_adaptive_beta = tail_adaptive_beta

def __call__(self, model, guide):
return _ELBOModule(model, guide, self)

def _guess_max_plate_nesting(self, model, guide, args, kwargs):
"""
Guesses max_plate_nesting by running the (model,guide) pair once
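Note: with this change an ELBO instance becomes callable on a (model, guide) pair and returns a torch.nn.Module whose forward evaluates differentiable_loss, so training can be driven by a plain torch.optim optimizer instead of pyro.optim/SVI. The following is a minimal sketch of that workflow, assuming module-local parameters are enabled; the regression model, guide, and data below are illustrative and not part of this diff.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule, PyroSample

pyro.enable_module_local_param(True)  # new toggle added by this PR

class BayesianRegression(PyroModule):
    def __init__(self):
        super().__init__()
        self.linear = PyroModule[torch.nn.Linear](1, 1)
        self.linear.weight = PyroSample(dist.Normal(0.0, 1.0).expand([1, 1]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0.0, 10.0).expand([1]).to_event(1))

    def forward(self, x, y=None):
        mean = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Normal(mean, 0.1), obs=y)
        return mean

model = BayesianRegression()
guide = AutoNormal(model)
x, y = torch.randn(100, 1), torch.randn(100)

# ELBO(model, guide) returns a torch.nn.Module; its forward evaluates
# differentiable_loss(model, guide, *args, **kwargs).
elbo = Trace_ELBO()(model, guide)
elbo(x, y)  # one warm-up call so the guide's lazily created parameters exist

optimizer = torch.optim.Adam(elbo.parameters(), lr=0.01)
for step in range(1000):
    optimizer.zero_grad()
    loss = elbo(x, y)
    loss.backward()
    optimizer.step()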
70 changes: 46 additions & 24 deletions pyro/nn/module.py
@@ -20,7 +20,8 @@
from torch.distributions import constraints, transform_to

import pyro
from pyro.poutine.runtime import _PYRO_PARAM_STORE
from pyro.params.param_store import ParamStoreDict
from pyro.poutine.runtime import _PYRO_PARAM_STORE, _module_local_param_enabled


class PyroParam(namedtuple("PyroParam", ("init_value", "constraint", "event_dim"))):
@@ -380,6 +381,8 @@ def __init__(self, name=""):
self._pyro_context = _Context() # shared among sub-PyroModules
self._pyro_params = OrderedDict()
self._pyro_samples = OrderedDict()
if _module_local_param_enabled():
self._pyro_param_store = ParamStoreDict()
super().__init__()

def add_module(self, name, module):
@@ -407,6 +410,12 @@ def named_pyro_params(self, prefix="", recurse=True):
for elem in gen:
yield elem

def _pyro_param_local(self, *args, **kwargs):
with pyro.use_param_store(
getattr(self, "_pyro_param_store", _PYRO_PARAM_STORE)
):
return pyro.param(*args, **kwargs)

def _pyro_set_supermodule(self, name, context):
self._pyro_name = name
self._pyro_context = context
@@ -434,33 +443,37 @@ def __getattr__(self, name):
unconstrained_value = getattr(self, name + "_unconstrained")
if self._pyro_context.active:
fullname = self._pyro_get_fullname(name)
if fullname in _PYRO_PARAM_STORE:
if fullname in self._pyro_param_store:
if (
_PYRO_PARAM_STORE._params[fullname]
self._pyro_param_store._params[fullname]
is not unconstrained_value
):
# Update PyroModule <--- ParamStore.
unconstrained_value = _PYRO_PARAM_STORE._params[fullname]
unconstrained_value = self._pyro_param_store._params[
fullname
]
if not isinstance(unconstrained_value, torch.nn.Parameter):
# Update PyroModule ---> ParamStore (type only; data is preserved).
unconstrained_value = torch.nn.Parameter(
unconstrained_value
)
_PYRO_PARAM_STORE._params[
self._pyro_param_store._params[
fullname
] = unconstrained_value
_PYRO_PARAM_STORE._param_to_name[
self._pyro_param_store._param_to_name[
unconstrained_value
] = fullname
super().__setattr__(
name + "_unconstrained", unconstrained_value
)
else:
# Update PyroModule ---> ParamStore.
_PYRO_PARAM_STORE._constraints[fullname] = constraint
_PYRO_PARAM_STORE._params[fullname] = unconstrained_value
_PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname
return pyro.param(fullname, event_dim=event_dim)
self._pyro_param_store._constraints[fullname] = constraint
self._pyro_param_store._params[fullname] = unconstrained_value
self._pyro_param_store._param_to_name[
unconstrained_value
] = fullname
return self._pyro_param_local(fullname, event_dim=event_dim)
else: # Cannot determine supermodule and hence cannot compute fullname.
return transform_to(constraint)(unconstrained_value)

@@ -491,7 +504,7 @@ def __getattr__(self, name):
"_unconstrained"
):
if self._pyro_context.active:
pyro.param(self._pyro_get_fullname(name), result)
self._pyro_param_local(self._pyro_get_fullname(name), result)

if isinstance(result, torch.nn.Module):
if isinstance(result, PyroModule):
@@ -508,6 +521,11 @@
return result

def __setattr__(self, name, value):

if isinstance(value, ParamStoreDict):
super().__setattr__(name, value)
return

if isinstance(value, PyroModule):
# Create a new sub PyroModule, overwriting any old value.
try:
@@ -527,19 +545,21 @@ def __setattr__(self, name, value):
self._pyro_params[name] = constraint, event_dim
if self._pyro_context.active:
fullname = self._pyro_get_fullname(name)
pyro.param(
self._pyro_param_local(
fullname,
constrained_value,
constraint=constraint,
event_dim=event_dim,
)
constrained_value = pyro.param(fullname)
constrained_value = self._pyro_param_local(fullname)
unconstrained_value = constrained_value.unconstrained()
if not isinstance(unconstrained_value, torch.nn.Parameter):
# Update PyroModule ---> ParamStore (type only; data is preserved).
unconstrained_value = torch.nn.Parameter(unconstrained_value)
_PYRO_PARAM_STORE._params[fullname] = unconstrained_value
_PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname
self._pyro_param_store._params[fullname] = unconstrained_value
self._pyro_param_store._param_to_name[
unconstrained_value
] = fullname
else: # Cannot determine supermodule and hence cannot compute fullname.
unconstrained_value = _unconstrain(constrained_value, constraint)
super().__setattr__(name + "_unconstrained", unconstrained_value)
@@ -553,12 +573,12 @@ def __setattr__(self, name, value):
pass
if self._pyro_context.active:
fullname = self._pyro_get_fullname(name)
value = pyro.param(fullname, value)
value = self._pyro_param_local(fullname, value)
if not isinstance(value, torch.nn.Parameter):
# Update PyroModule ---> ParamStore (type only; data is preserved).
value = torch.nn.Parameter(value)
_PYRO_PARAM_STORE._params[fullname] = value
_PYRO_PARAM_STORE._param_to_name[value] = fullname
self._pyro_param_store._params[fullname] = value
self._pyro_param_store._param_to_name[value] = fullname
super().__setattr__(name, value)
return

@@ -590,19 +610,19 @@ def __delattr__(self, name):
del self._parameters[name]
if self._pyro_context.used:
fullname = self._pyro_get_fullname(name)
if fullname in _PYRO_PARAM_STORE:
if fullname in self._pyro_param_store:
# Update PyroModule ---> ParamStore.
del _PYRO_PARAM_STORE[fullname]
del self._pyro_param_store[fullname]
return

if name in self._pyro_params:
delattr(self, name + "_unconstrained")
del self._pyro_params[name]
if self._pyro_context.used:
fullname = self._pyro_get_fullname(name)
if fullname in _PYRO_PARAM_STORE:
if fullname in self._pyro_param_store:
# Update PyroModule ---> ParamStore.
del _PYRO_PARAM_STORE[fullname]
del self._pyro_param_store[fullname]
return

if name in self._pyro_samples:
Expand All @@ -613,9 +633,9 @@ def __delattr__(self, name):
del self._modules[name]
if self._pyro_context.used:
fullname = self._pyro_get_fullname(name)
for p in list(_PYRO_PARAM_STORE.keys()):
for p in list(self._pyro_param_store.keys()):
if p.startswith(fullname):
del _PYRO_PARAM_STORE[p]
del self._pyro_param_store[p]
return

super().__delattr__(name)
@@ -699,6 +719,8 @@ def to_pyro_module_(m, recurse=True):
m._pyro_context = _Context()
m._pyro_params = OrderedDict()
m._pyro_samples = OrderedDict()
if _module_local_param_enabled():
m._pyro_param_store = ParamStoreDict()

# Reregister parameters and submodules.
for name, value in list(m._parameters.items()):
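Note: when enable_module_local_param(True) is in effect, each PyroModule constructed afterwards carries its own ParamStoreDict (self._pyro_param_store), and the store bookkeeping above targets that local store rather than the global one. A small sketch of the observable difference, assuming the flag is set before the module is built; the Scaler module is illustrative, not part of this diff.

import torch
import pyro
from pyro.distributions import constraints
from pyro.nn import PyroModule, PyroParam

pyro.enable_module_local_param(True)

class Scaler(PyroModule):
    def __init__(self):
        super().__init__()
        self.scale = PyroParam(torch.ones(2), constraint=constraints.positive)

    def forward(self, x):
        return self.scale * x

m = Scaler()
m(torch.randn(2))  # running the module registers 'scale' in its local store

# The parameter is an ordinary nn.Parameter on the module, so m.parameters()
# exposes it to a vanilla PyTorch optimizer ...
print([name for name, _ in m.named_parameters()])  # ['scale_unconstrained']
# ... while the process-global store is left untouched.
print("scale" in pyro.get_param_store())  # False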
12 changes: 12 additions & 0 deletions pyro/poutine/runtime.py
@@ -15,6 +15,18 @@
# the global ParamStore
_PYRO_PARAM_STORE = ParamStoreDict()

# toggle usage of local param stores in PyroModules
_PYRO_MODULE_LOCAL_PARAM = False


def enable_module_local_param(flag: bool) -> None:
global _PYRO_MODULE_LOCAL_PARAM
_PYRO_MODULE_LOCAL_PARAM = flag


def _module_local_param_enabled():
return _PYRO_MODULE_LOCAL_PARAM


class _DimAllocator:
"""
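Note: the flag is a process-global setting that PyroModule.__init__ (and to_pyro_module_) consults at construction time, so it needs to be enabled before the relevant modules are instantiated. A tiny sketch of that ordering assumption:

import pyro

pyro.enable_module_local_param(True)   # modules built from here on get local stores
# ... construct the PyroModules that should own their parameters ...
pyro.enable_module_local_param(False)  # modules built afterwards use the global store again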
15 changes: 14 additions & 1 deletion pyro/primitives.py
@@ -14,6 +14,7 @@
import pyro.poutine as poutine
from pyro.distributions import constraints
from pyro.params import param_with_module_name
from pyro.params.param_store import ParamStoreDict
from pyro.poutine.plate_messenger import PlateMessenger
from pyro.poutine.runtime import (
_MODULE_NAMESPACE_DIVIDER,
@@ -45,7 +46,19 @@ def clear_param_store():
return _PYRO_PARAM_STORE.clear()


_param = effectful(_PYRO_PARAM_STORE.get_param, type="param")
@contextmanager
def use_param_store(param_store: ParamStoreDict):
try:
global _PYRO_PARAM_STORE
_PYRO_PARAM_STORE, prev_store = param_store, _PYRO_PARAM_STORE
yield param_store
finally:
_PYRO_PARAM_STORE = prev_store


@effectful(type="param")
def _param(*args, **kwargs):
return _PYRO_PARAM_STORE.get_param(*args, **kwargs)


def param(name, init_tensor=None, constraint=constraints.real, event_dim=None):
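Note: use_param_store temporarily rebinds the store consulted by pyro.param (and get_param_store) for the duration of the with block and restores the previous store on exit, even if the block raises; the decorator-based _param now looks the store up at call time, so the swap is honored. A short sketch, with a scratch store that is purely illustrative:

import torch
import pyro
from pyro.params.param_store import ParamStoreDict

scratch = ParamStoreDict()
with pyro.use_param_store(scratch):
    pyro.param("tmp", torch.zeros(3))        # written into `scratch`

assert "tmp" in scratch                      # the scratch store received the param
assert "tmp" not in pyro.get_param_store()   # the global store is unchanged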