Expectation Programming Features #230

Draft: wants to merge 5 commits into staging-dynamic
Empty file added chirho/contrib/__init__.py
8 changes: 8 additions & 0 deletions chirho/contrib/compexp/__init__.py
@@ -0,0 +1,8 @@
from .composeable_expectation.composed_expectation import ComposedExpectation
from .composeable_expectation.expectation_atom import ExpectationAtom
from .handlers.expectation_handler import ExpectationHandler
from .handlers.importance_sampling_expectation_handler import ImportanceSamplingExpectationHandler
from .handlers.montecarlo_expectation_handler import MonteCarloExpectationHandler
from .typedecs import StochasticFunction, ExpectationFunction
from .composeable_expectation import grad
E = ExpectationAtom
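
A minimal usage sketch (not part of the diff), assuming the KWType convention that models return their sampled values as a dict of named tensors; the handler's constructor arguments are a guess, since the handler implementations appear outside this section.

import pyro
import pyro.distributions as dist
from chirho.contrib.compexp import E, MonteCarloExpectationHandler

def model():
    # Assumed convention: return sampled sites as a dict of tensors.
    return {"x": pyro.sample("x", dist.Normal(0., 1.))}

# E[x^2], an ExpectationAtom built through the E alias.
x2 = E(f=lambda s: s["x"] ** 2, name="x2")

# Hypothetical constructor signature; the real handler may differ.
with MonteCarloExpectationHandler(num_samples=1000):
    print(x2(model))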
Empty file.
55 changes: 55 additions & 0 deletions chirho/contrib/compexp/composeable_expectation/composed_expectation.py
@@ -0,0 +1,55 @@
from typing import Callable, List, TYPE_CHECKING
from ..typedecs import ModelType
import torch
from torch import Tensor as TT
if TYPE_CHECKING:
from .expectation_atom import ExpectationAtom


class ComposedExpectation:
def __init__(
self,
children: List["ComposedExpectation"],
op: Callable[..., TT],
parts: List["ExpectationAtom"]
):
self.op = op
self.children = children
self.parents: List[ComposedExpectation] = []
self.parts = parts

self._normalization_constant_cancels = False

def recursively_refresh_parts(self):
# Invalidate the cached parts before recursing, then rebuild the
# flat list of leaf atoms from the children.
self.parts = None

for child in self.children:
child.recursively_refresh_parts()

self.parts = []
for child in self.children:
self.parts.extend(child.parts)

def __op_other(self, other: "ComposedExpectation", op: Callable[[TT, TT], TT]) -> "ComposedExpectation":
ce = ComposedExpectation(
children=[self, other],
op=op,
parts=self.parts + other.parts)
self.parents.append(ce)
other.parents.append(ce)
return ce

def __truediv__(self, other: "ComposedExpectation") -> "ComposedExpectation":
return self.__op_other(other, torch.divide)

def __add__(self, other: "ComposedExpectation") -> "ComposedExpectation":
return self.__op_other(other, torch.add)

def __mul__(self, other: "ComposedExpectation") -> "ComposedExpectation":
return self.__op_other(other, torch.multiply)

def __sub__(self, other: "ComposedExpectation") -> "ComposedExpectation":
return self.__op_other(other, torch.subtract)

def __call__(self, p: ModelType) -> TT:
return self.op(*[child(p) for child in self.children])
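
For intuition, a small sketch of what these overloads build (using ExpectationAtom leaves, defined in the next file): each binary operation creates a new ComposedExpectation node whose parts list concatenates the leaf atoms of its operands.

import torch
from chirho.contrib.compexp import ExpectationAtom

pos = ExpectationAtom(f=lambda s: torch.relu(s["x"]), name="pos")
neg = ExpectationAtom(f=lambda s: torch.relu(-s["x"]), name="neg")
den = ExpectationAtom(f=lambda s: torch.tensor(1.), name="den")

# (E[relu(x)] - E[relu(-x)]) / E[1]: a tree with torch.subtract and
# torch.divide as op nodes and the three atoms as leaves.
ratio = (pos - neg) / den
assert [p.name for p in ratio.parts] == ["pos", "neg", "den"]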
93 changes: 93 additions & 0 deletions chirho/contrib/compexp/composeable_expectation/expectation_atom.py
@@ -0,0 +1,93 @@
import pyro
from typing import Optional
import torch
from torch import tensor as tt
from torch import Tensor as TT
from ..typedecs import ModelType, KWType
from .composed_expectation import ComposedExpectation
from ..typedecs import StochasticFunction
from ..ops import _compute_expectation_atom


class ExpectationAtom(ComposedExpectation):

def __init__(
self,
f: StochasticFunction, # TODO say in docstring that this has to have scalar valued output.
name: str,
log_fac_eps: float = 1e-45,
guide: Optional[ModelType] = None):

self.f = f
self.name = name
self.log_fac_eps = log_fac_eps
self.guide = guide

self._is_positive_everywhere = False

super().__init__(children=[], op=lambda v: v, parts=[self])

def recursively_refresh_parts(self):
self.parts = [self]

def __call__(self, p: ModelType) -> TT:
"""
Overrides the non-atomic call to actually estimate the value of this expectation atom.
"""
ret = _compute_expectation_atom(self, p)
if ret.ndim != 0:
raise ValueError(f"Argument f to {ExpectationAtom.__name__} with name {self.name} must return a scalar,"
f" but got {ret} instead.")
return ret

def build_pseudo_density(self, p: ModelType) -> ModelType:

if not self._is_positive_everywhere:
raise NotImplementedError("Non-positive pseudo-density construction is not supported. "
f"Convert atom named {self.name} by using the output of "
"ExpectationAtom.split_into_positive_components().")

# This defines a new density that is the product of the density defined by p and this all-positive function
# we want to take the expectation wrt.
def pseudo_density() -> KWType:
stochastics = p()
factor_name = self.name + "_factor"
pyro.factor(factor_name, torch.log(self.log_fac_eps + self.f(stochastics)))
return stochastics

return pseudo_density

# TODO maybe rename this to get_tabi_decomposition.
def split_into_positive_components(
self,
# TODO bdt18dosjk maybe don't allow for guide specification, but rather handler specification that
# may specify a particular guide arrangement?
pos_guide: Optional[ModelType] = None,
neg_guide: Optional[ModelType] = None,
den_guide: Optional[ModelType] = None) -> "ComposedExpectation":

pos_part = ExpectationAtom(
f=lambda s: torch.relu(self.f(s)), name=self.name + "_split_pos", guide=pos_guide)
pos_part._is_positive_everywhere = True
neg_part = ExpectationAtom(
f=lambda s: torch.relu(-self.f(s)), name=self.name + "_split_neg", guide=neg_guide)
neg_part._is_positive_everywhere = True
den_part = ExpectationAtom(
f=lambda s: tt(1.), name=self.name + "_split_den", guide=den_guide)
den_part._is_positive_everywhere = True

ret: ComposedExpectation = (pos_part - neg_part) / den_part
ret._normalization_constant_cancels = True

return ret

def swap_self_for_other_child(self, other: ComposedExpectation) -> None:
for parent in self.parents:
positions_as_child = [i for i, child in enumerate(parent.children) if child is self]
assert len(positions_as_child) >= 1, "This shouldn't be possible." \
" There's a reference mismatch with parents."
# Now, swap out the old atom with the new composite.
for pac in positions_as_child:
parent.children[pac] = other
other.parents.append(parent)
self.parents = []
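
A sketch of the rewiring that split_into_positive_components and swap_self_for_other_child perform together, assuming an atom with a signed integrand:

import torch
from chirho.contrib.compexp import ExpectationAtom

signed = ExpectationAtom(f=lambda s: s["x"], name="fx")
other = ExpectationAtom(f=lambda s: s["x"] ** 2, name="fx2")
total = signed + other

# Replace the signed atom with its three-part TABI decomposition
# (pos - neg) / den, then refresh the root's cached parts list.
tabi = signed.split_into_positive_components()
signed.swap_self_for_other_child(tabi)
total.recursively_refresh_parts()
assert len(total.parts) == 4  # fx_split_pos, fx_split_neg, fx_split_den, fx2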
113 changes: 113 additions & 0 deletions chirho/contrib/compexp/composeable_expectation/grad.py
@@ -0,0 +1,113 @@
from typing import Callable
from ..typedecs import KWType
import torch
from torch import Tensor as TT
import warnings
from .expectation_atom import ExpectationAtom
from .composed_expectation import ComposedExpectation


def _build_df_dd(dparams: TT, di: int, part: ExpectationAtom) -> Callable[[KWType], TT]:
def df_dd(stochastics: KWType) -> TT:
y: TT = part.f(stochastics)

if y.ndim != 0:
raise ValueError(f"Argument f to {ExpectationAtom.__name__} with name {part.name} must return a scalar,"
f" but got {y} instead.")

assert dparams[di].ndim == 0, "This shouldn't be possible due to the outer check of 1 dimension."

try:
df_ddparam, = torch.autograd.grad(
outputs=(y,),
# FIXME HACK Have to grad wrt the whole tensor apparently, and then index after.
inputs=(dparams,),
create_graph=True
)
df_ddparam = df_ddparam[di]
except RuntimeError as e:
if "does not require grad and does not have a grad_fn" in str(e):
# FIXME FIXME kf2801dgi1 this is only correct when this particular atom is a mul or div of one
#  that does require grad. It would be nice to not have to repro autodiff here, but
#  somehow this needs to know how the parent incorporates this atom. Maybe we could
#  autodiff the parent's op and see how this atom relates, then use that to determine
#  what should be returned here?
warnings.warn(f"The gradient of atom named {part.name} with respect to dparam {di}"
f" is 0.0, but returning the original atom's value for now because"
f" it's probably a scaling factor. This is a hack and should be fixed."
f" See FIXME tagged kf2801dgi1.")
return y
else:
raise

assert df_ddparam.ndim == 0, "This shouldn't be possible due to out and in being 0 dimensional."

return df_ddparam

return df_dd


# FIXME 7301ykd0sk See below. Want to convert to a proper in-place operation with no return value.
def gradify_in_place_but_need_return(
output: ComposedExpectation, dparams: TT, split_atoms=False) -> ComposedExpectation:

if dparams.ndim != 1:
raise ValueError(f"Argument dparams to {gradify_in_place_but_need_return.__name__} must be a 1d tensor, "
f"but got ndim {dparams.ndim} instead.")

assert len(output.parts) >= 1, "This shouldn't be possible due to composites always having at least one " \
"part (themselves)."

if not len(dparams) >= 1:
raise ValueError(f"Argument dparams to {gradify_in_place_but_need_return.__name__} must have at least one "
f"element, but got {len(dparams)} instead.")

# Only relevant if output is an atom. Just defining outside of loop so type checking is happy below.
sub_atom_composite = None

for part in output.parts:

sub_atoms = []

# Create a new atom for each of the old atoms.
for di, _ in enumerate(dparams):

# Create a new atom just for just this element of the gradient vector.
ea = ExpectationAtom(
f=_build_df_dd(dparams, di, part),
name=f"d{part.name}_dd{di}",
log_fac_eps=part.log_fac_eps
# TODO maybe seed a new guide with the original guide (if present)?
)

if split_atoms:
ea = ea.split_into_positive_components()

sub_atoms.append(ea)

# Create a composite that simply concatenates the new atoms into one tensor.
sub_atom_composite = ComposedExpectation(
children=sub_atoms,
op=lambda *v: torch.stack(v, dim=0),
# Note bm72gdi1: This will be updated before the return of this function.
parts=[]
)

# Swap out the old atom for the new composite wherever it appears as a child.
part.swap_self_for_other_child(sub_atom_composite)

if isinstance(output, ExpectationAtom):
assert len(output.parts) == 1 and output.parts[0] is output, \
"This shouldn't be possible: atom code broken?"
# FIXME 7301ykd0sk this is why you have to take the return value here...
output = sub_atom_composite

# Note bm72gdi1 this ensures that the cached part list is up-to-date.
output.recursively_refresh_parts()

return output
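
A usage sketch, assuming the parameters are a 1d leaf tensor that the atom's f closes over; note that the return value must be used, since an atomic root is replaced outright (see FIXME 7301ykd0sk):

import torch
from chirho.contrib.compexp import ExpectationAtom
from chirho.contrib.compexp.composeable_expectation.grad import gradify_in_place_but_need_return

theta = torch.tensor([1.0, 2.0], requires_grad=True)
cost = ExpectationAtom(f=lambda s: (theta * s["x"]).sum(), name="cost")

# Rewrites cost into a composite stacking one atom per d/d(theta[i]),
# each further split into its positive TABI components.
dcost = gradify_in_place_but_need_return(cost, dparams=theta, split_atoms=True)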
Empty file.
12 changes: 12 additions & 0 deletions chirho/contrib/compexp/handlers/expectation_handler.py
@@ -0,0 +1,12 @@
import pyro


class ExpectationHandler(pyro.poutine.messenger.Messenger):

# noinspection PyMethodMayBeStatic
def _pyro_compute_expectation_atom(self, msg) -> None:
if msg["done"]:
# TODO bdt18dosjk Do something similar to the demo1 setup where we define an OOP interface as well.
# Then handler can be passed to expectation atoms so that users can specify different handlers
# for different atoms.
raise RuntimeError("Only one default expectation handler can be in effect at a time.")
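
For orientation, a hypothetical concrete subclass (the real handlers in this PR live in separate files not shown here). It assumes _compute_expectation_atom is an effectful operation defined in ..ops whose arguments are the atom and the model, so a handler responds by filling in msg["value"] and marking the message done:

class NaiveExpectationHandler(ExpectationHandler):

    def _pyro_compute_expectation_atom(self, msg) -> None:
        super()._pyro_compute_expectation_atom(msg)  # enforce the one-handler rule
        atom, p = msg["args"]  # assumed argument layout of the effectful op
        # Single-draw estimate of E_p[f(x)]; purely illustrative.
        msg["value"] = atom.f(p())
        msg["done"] = True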