From 0b2466ffb06c29af74e1f26e1dd701464e31dc8a Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 28 May 2026 19:39:36 +0200 Subject: [PATCH 1/3] Refactor marginalization with logprob rewrites and conditional() API Redesign marginalization around two concepts: 1. marginalize(model) creates MarginalSubgraph markers at build time, resolved to typed MarginalRVs (e.g. MarginalFiniteDiscreteRV) via an EquilibriumDB of logprob rewrites. 2. conditional(model) replaces marginalized variables with their conditional posterior distributions as free RVs, returning a standard PyMC model that works with compile_logp and sample_posterior_predictive. recover_marginals now uses conditional() + sample_posterior_predictive internally. Other changes: - Remove the deprecated MarginalModel class - Split marginal_model.py into model.py, distributions/, rewrites.py - Support sequential marginalize via unwrap_inner_marginal_rv rewrite --- conda-envs/environment-test.yml | 2 +- docs/api_reference.rst | 13 +- pymc_extras/__init__.py | 5 +- pymc_extras/inference/INLA/inla.py | 2 +- pymc_extras/marginal.py | 14 + pymc_extras/model/marginal/__init__.py | 8 + .../model/marginal/distributions/__init__.py | 6 + .../model/marginal/distributions/core.py | 93 +++ .../enumerable.py} | 336 ++------ .../model/marginal/distributions/laplace.py | 177 ++++ pymc_extras/model/marginal/graph_analysis.py | 25 +- pymc_extras/model/marginal/marginal_model.py | 647 --------------- pymc_extras/model/marginal/model.py | 774 ++++++++++++++++++ pymc_extras/model/marginal/rewrites.py | 334 ++++++++ pyproject.toml | 2 +- tests/model/marginal/test_distributions.py | 3 +- .../{test_marginal_model.py => test_model.py} | 108 +-- tests/model/marginal/test_rewrites.py | 260 ++++++ 18 files changed, 1819 insertions(+), 990 deletions(-) create mode 100644 pymc_extras/marginal.py create mode 100644 pymc_extras/model/marginal/distributions/__init__.py create mode 100644 pymc_extras/model/marginal/distributions/core.py rename pymc_extras/model/marginal/{distributions.py => distributions/enumerable.py} (54%) create mode 100644 pymc_extras/model/marginal/distributions/laplace.py delete mode 100644 pymc_extras/model/marginal/marginal_model.py create mode 100644 pymc_extras/model/marginal/model.py create mode 100644 pymc_extras/model/marginal/rewrites.py rename tests/model/marginal/{test_marginal_model.py => test_model.py} (91%) create mode 100644 tests/model/marginal/test_rewrites.py diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 021c14af0..081f3413e 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -13,7 +13,7 @@ dependencies: - pytest-cov - pydantic>=2.0.0 - h5netcdf - - pymc>=6.0,<7.0 + - pymc>=6.0.1,<7.0 - preliz>=0.26,<0.27 - pip - pip: diff --git a/docs/api_reference.rst b/docs/api_reference.rst index e21772b8b..17ce77e04 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -13,9 +13,20 @@ methods in the current release of PyMC experimental. as_model marginalize - recover_marginals model_builder.ModelBuilder +Marginalization +=============== + +.. currentmodule:: pymc_extras.marginal +.. autosummary:: + :toctree: generated/ + + marginalize + conditional + unmarginalize + recover + Inference ========= diff --git a/pymc_extras/__init__.py b/pymc_extras/__init__.py index cee0ffeb5..00691dd40 100644 --- a/pymc_extras/__init__.py +++ b/pymc_extras/__init__.py @@ -15,11 +15,10 @@ from importlib.metadata import version -from pymc_extras import gp, statespace, utils +from pymc_extras import gp, marginal, statespace, utils from pymc_extras.distributions import * from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder -from pymc_extras.model.marginal.marginal_model import ( - MarginalModel, +from pymc_extras.model.marginal.model import ( marginalize, recover_marginals, ) diff --git a/pymc_extras/inference/INLA/inla.py b/pymc_extras/inference/INLA/inla.py index b128e3bc9..33393df14 100644 --- a/pymc_extras/inference/INLA/inla.py +++ b/pymc_extras/inference/INLA/inla.py @@ -5,7 +5,7 @@ from pytensor.tensor import TensorLike, TensorVariable, as_tensor from xarray import DataTree -from pymc_extras.model.marginal.marginal_model import marginalize +from pymc_extras.model.marginal.model import marginalize def fit_INLA( diff --git a/pymc_extras/marginal.py b/pymc_extras/marginal.py new file mode 100644 index 000000000..d447a264b --- /dev/null +++ b/pymc_extras/marginal.py @@ -0,0 +1,14 @@ +"""Public namespace for marginalization utilities. + +The implementation lives in :mod:`pymc_extras.model.marginal`; this module +re-exports the public API under the shorter ``pymc_extras.marginal`` path. +""" + +from pymc_extras.model.marginal.model import ( + conditional, + marginalize, + recover, + unmarginalize, +) + +__all__ = ["conditional", "marginalize", "recover", "unmarginalize"] diff --git a/pymc_extras/model/marginal/__init__.py b/pymc_extras/model/marginal/__init__.py index e69de29bb..f43476c24 100644 --- a/pymc_extras/model/marginal/__init__.py +++ b/pymc_extras/model/marginal/__init__.py @@ -0,0 +1,8 @@ +import pymc_extras.model.marginal.rewrites # noqa: F401 + +from pymc_extras.model.marginal.model import ( # noqa: F401 + conditional, + marginalize, + recover, + unmarginalize, +) diff --git a/pymc_extras/model/marginal/distributions/__init__.py b/pymc_extras/model/marginal/distributions/__init__.py new file mode 100644 index 000000000..7e52e2ee9 --- /dev/null +++ b/pymc_extras/model/marginal/distributions/__init__.py @@ -0,0 +1,6 @@ +import pymc_extras.model.marginal.distributions.enumerable +import pymc_extras.model.marginal.distributions.laplace # noqa: F401 + +from pymc_extras.model.marginal.distributions.enumerable import ( + MarginalFiniteDiscreteRV, # noqa: F401 +) diff --git a/pymc_extras/model/marginal/distributions/core.py b/pymc_extras/model/marginal/distributions/core.py new file mode 100644 index 000000000..2507837a1 --- /dev/null +++ b/pymc_extras/model/marginal/distributions/core.py @@ -0,0 +1,93 @@ +from collections.abc import Sequence +from functools import singledispatch + +from pymc.distributions.distribution import _support_point, support_point +from pymc.logprob.abstract import MeasurableOp +from pytensor.compile.builders import OpFromGraph +from pytensor.graph import FunctionGraph +from pytensor.graph.basic import Variable +from pytensor.graph.replace import graph_replace +from pytensor.tensor.random.type import RandomType + + +def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]: + """Inline the inner graph (outputs) of an OpFromGraph Op. + + Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps" + the inner graph. + """ + return graph_replace( + op.inner_outputs, + replace=tuple(zip(op.inner_inputs, inputs)), + strict=False, + ) + + +class MarginalRV(OpFromGraph, MeasurableOp): + """Base class for supported MarginalRVs.""" + + +@_support_point.register(MarginalRV) +def _support_point_marginal_rv(op, rv, *inputs): + outputs = rv.owner.outputs + + fgraph = op.fgraph.clone() + inner_inputs = fgraph.inputs + inner_outputs = fgraph.outputs + del op + + inner_rv = inner_outputs[outputs.index(rv)] + marginalized_inner_rv, *other_dependent_inner_rvs = ( + out for out in inner_outputs if out is not inner_rv and not isinstance(out.type, RandomType) + ) + + marginalized_inner_rv_dummy = marginalized_inner_rv.clone() + inner_to_dummy_replacements = [] + dummy_to_outer_replacements = [] + for other_inner_rv in other_dependent_inner_rvs: + dummy = other_inner_rv.clone() + inner_to_dummy_replacements.append((other_inner_rv, dummy)) + dummy_to_outer_replacements.append((dummy, outputs[inner_outputs.index(other_inner_rv)])) + + fgraph.replace(marginalized_inner_rv, marginalized_inner_rv_dummy, import_missing=True) + fgraph.replace_all(tuple(inner_to_dummy_replacements), import_missing=True) + + inner_rv_support_point = support_point(inner_rv) + marginalized_inner_rv_support_point = support_point(marginalized_inner_rv) + + fgraph = FunctionGraph(outputs=[inner_rv_support_point], clone=False) + fgraph.replace( + marginalized_inner_rv_dummy, marginalized_inner_rv_support_point, import_missing=True + ) + fgraph.replace_all(tuple(zip(inner_inputs, inputs)), import_missing=True) + fgraph.replace_all(tuple(dummy_to_outer_replacements), import_missing=True) + + [rv_support_point] = fgraph.outputs + return rv_support_point + + +@singledispatch +def marginalized_conditional(op, node): + """Build the conditional distribution of a marginalized variable given its dependents. + + Dispatches on the MarginalRV op type. + + The inner graph of a MarginalRV is generative: it draws the marginalized + variable and then the dependents given it, factoring as + ``p(marginalized | inputs) * p(dependents | marginalized, inputs)``. + This function returns the reverse factor + ``p(marginalized | dependents, inputs)``, where the dependents are given + values rather than random draws: a Categorical over the enumerated domain + weighted by the joint logp for finite discrete marginals, the conjugate + posterior Normal for Normal-Normal. + + Returns ``(sample_graph, dep_dummies)`` where *sample_graph* is a random + variable distributed as ``p(marginalized | dependents, inputs)``, + expressed over the op's ``inner_inputs``, and *dep_dummies* are + placeholder tensors standing in for the dependent values. The caller + replaces the dummies with the actual model variables (or observed data) + and the inner inputs with the node's inputs. + """ + raise NotImplementedError( + f"Cannot recover marginalized variable with distribution {type(op).__name__}" + ) diff --git a/pymc_extras/model/marginal/distributions.py b/pymc_extras/model/marginal/distributions/enumerable.py similarity index 54% rename from pymc_extras/model/marginal/distributions.py rename to pymc_extras/model/marginal/distributions/enumerable.py index 2bbf57b17..6c1b241a2 100644 --- a/pymc_extras/model/marginal/distributions.py +++ b/pymc_extras/model/marginal/distributions/enumerable.py @@ -3,41 +3,41 @@ from collections.abc import Sequence import numpy as np -import pytensor import pytensor.tensor as pt from pymc.distributions import Bernoulli, Categorical, DiscreteUniform -from pymc.distributions.distribution import _support_point, support_point -from pymc.distributions.multivariate import _logdet_from_cholesky -from pymc.logprob.abstract import MeasurableOp, _logprob +from pymc.logprob.abstract import _logprob from pymc.logprob.basic import conditional_logp, logp from pymc.pytensorf import constant_fold -from pytensor.compile.builders import OpFromGraph from pytensor.compile.mode import Mode -from pytensor.graph import FunctionGraph, Op, vectorize_graph -from pytensor.graph.basic import Variable, equal_computations +from pytensor.graph import Op, vectorize_graph from pytensor.graph.replace import clone_replace, graph_replace from pytensor.scan import map as scan_map from pytensor.scan import scan -from pytensor.tensor import TensorLike, TensorVariable -from pytensor.tensor.optimize import minimize -from pytensor.tensor.random.type import RandomType +from pytensor.tensor import TensorVariable from pymc_extras.distributions import DiscreteMarkovChain +from pymc_extras.model.marginal.distributions.core import ( + MarginalRV, + inline_ofg_outputs, + marginalized_conditional, +) -class MarginalRV(OpFromGraph, MeasurableOp): - """Base class for Marginalized RVs""" +class EnumerableMarginalRV(MarginalRV): + """Base class for enumerable Marginalized RVs with closed-form logp.""" def __init__( self, *args, dims_connections: tuple[tuple[int | None], ...], - dims: tuple[Variable, ...], + marginalized_dims, + n_dependent_rvs: int, **kwargs, ) -> None: self.dims_connections = dims_connections - self.dims = dims + self.marginalized_dims = marginalized_dims + self.n_dependent_rvs = n_dependent_rvs super().__init__(*args, **kwargs) @property @@ -57,104 +57,38 @@ def support_axes(self) -> tuple[tuple[int]]: ) return tuple(support_axes_vars) - def __eq__(self, other): - # Just to allow easy testing of equivalent models, - # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed - if type(self) is not type(other): - return False - - return equal_computations( - self.inner_outputs, - other.inner_outputs, - self.inner_inputs, - other.inner_inputs, - ) - - def __hash__(self): - # Just to allow easy testing of equivalent models, - # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed - return hash((type(self), len(self.inner_inputs), len(self.inner_outputs))) +class NonSeparableLogpWarning(UserWarning): + pass -@_support_point.register -def support_point_marginal_rv(op: MarginalRV, rv, *inputs): - """Support point for a marginalized RV. - The support point of a marginalized RV is the support point of the inner RV, - conditioned on the marginalized RV taking its support point. - """ - outputs = rv.owner.outputs +def warn_non_separable_logp(values): + if len(values) > 1: + warnings.warn( + "There are multiple dependent variables in a FiniteDiscreteMarginalRV. " + f"Their joint logp terms will be assigned to the first value: {values[0]}.", + NonSeparableLogpWarning, + stacklevel=2, + ) - fgraph = op.fgraph.clone() - inner_inputs = fgraph.inputs - inner_outputs = fgraph.outputs - del op - inner_rv = inner_outputs[outputs.index(rv)] - marginalized_inner_rv, *other_dependent_inner_rvs = ( - out for out in inner_outputs if out is not inner_rv and not isinstance(out.type, RandomType) - ) +DUMMY_ZERO = pt.constant(0, name="dummy_zero") - # Replace references to inner rvs by the dummy variables (including the marginalized RV) - # This is necessary because the inner RVs may depend on each other - marginalized_inner_rv_dummy = marginalized_inner_rv.clone() - # Map inner rvs to dummies, saving what outer output each corresponds to. - # We need dummies because inner RVs may depend on each other. - inner_to_dummy_replacements = [] - dummy_to_outer_replacements = [] - for other_inner_rv in other_dependent_inner_rvs: - dummy = other_inner_rv.clone() - inner_to_dummy_replacements.append((other_inner_rv, dummy)) - dummy_to_outer_replacements.append((dummy, outputs[inner_outputs.index(other_inner_rv)])) - - fgraph.replace(marginalized_inner_rv, marginalized_inner_rv_dummy, import_missing=True) - fgraph.replace_all(tuple(inner_to_dummy_replacements), import_missing=True) - - # Get support point of inner RV and marginalized RV - inner_rv_support_point = support_point(inner_rv) - marginalized_inner_rv_support_point = support_point(marginalized_inner_rv) - - fgraph = FunctionGraph(outputs=[inner_rv_support_point], clone=False) - # Replace the marginalized RV dummy by its support point - fgraph.replace( - marginalized_inner_rv_dummy, marginalized_inner_rv_support_point, import_missing=True - ) - # Replace the inner inputs by the outer inputs - fgraph.replace_all(tuple(zip(inner_inputs, inputs)), import_missing=True) - # Replace other dependent RVs dummies by the respective outer outputs. - # PyMC will replace them by their support points later - fgraph.replace_all(tuple(dummy_to_outer_replacements), import_missing=True) - [rv_support_point] = fgraph.outputs - return rv_support_point +def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> TensorVariable: + """Align the logp with the order specified in dims.""" + dims_alignment = [dim for dim in dims if dim is not None] + return logp.transpose(*dims_alignment) -class MarginalFiniteDiscreteRV(MarginalRV): +class MarginalFiniteDiscreteRV(EnumerableMarginalRV): """Base class for Marginalized Finite Discrete RVs""" -class MarginalDiscreteMarkovChainRV(MarginalRV): +class MarginalDiscreteMarkovChainRV(EnumerableMarginalRV): """Base class for Marginalized Discrete Markov Chain RVs""" -class MarginalLaplaceRV(MarginalRV): - """Base class for Marginalized Laplace-Approximated RVs. - - Estimates log likelihood using Laplace approximations. - """ - - def __init__( - self, - *args, - minimizer_seed: int, - minimizer_kwargs: dict = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}, - **kwargs, - ) -> None: - self.minimizer_seed = minimizer_seed - self.minimizer_kwargs = minimizer_kwargs - super().__init__(*args, **kwargs) - - def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]: op = rv.owner.op dist_params = rv.owner.op.dist_params(rv.owner) @@ -223,46 +157,12 @@ def reduce_batch_dependent_logps( return reduced_logp -def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> TensorVariable: - """Align the logp with the order specified in dims.""" - dims_alignment = [dim for dim in dims if dim is not None] - return logp.transpose(*dims_alignment) - - -def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]: - """Inline the inner graph (outputs) of an OpFromGraph Op. - - Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps" - the inner graph. - """ - return graph_replace( - op.inner_outputs, - replace=tuple(zip(op.inner_inputs, inputs)), - strict=False, - ) - - -class NonSeparableLogpWarning(UserWarning): - pass - - -def warn_non_separable_logp(values): - if len(values) > 1: - warnings.warn( - "There are multiple dependent variables in a FiniteDiscreteMarginalRV. " - f"Their joint logp terms will be assigned to the first value: {values[0]}.", - NonSeparableLogpWarning, - stacklevel=2, - ) - - -DUMMY_ZERO = pt.constant(0, name="dummy_zero") - - @_logprob.register(MarginalFiniteDiscreteRV) def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs): # Clone the inner RV graph of the Marginalized RV - marginalized_rv, *inner_rvs = inline_ofg_outputs(op, inputs) + all_outputs = inline_ofg_outputs(op, inputs) + marginalized_rv = all_outputs[0] + inner_rvs = list(all_outputs[1 : 1 + op.n_dependent_rvs]) # Obtain the joint_logp graph of the inner RV graph inner_rv_values = dict(zip(inner_rvs, values)) @@ -326,7 +226,9 @@ def logp_fn(marginalized_rv_const, *non_sequences): @_logprob.register(MarginalDiscreteMarkovChainRV) def marginal_hmm_logp(op, values, *inputs, **kwargs): - chain_rv, *dependent_rvs = inline_ofg_outputs(op, inputs) + all_outputs = inline_ofg_outputs(op, inputs) + chain_rv = all_outputs[0] + dependent_rvs = list(all_outputs[1 : 1 + op.n_dependent_rvs]) P, n_steps_, init_dist_, rng = chain_rv.owner.inputs domain = pt.arange(P.shape[-1], dtype="int32") @@ -400,135 +302,51 @@ def step_alpha(logp_emission, log_alpha, log_P): return joint_logp, *dummy_logps -def _precision_mv_normal_logp(value: TensorLike, mean: TensorLike, tau: TensorLike): - """ - Compute the log likelihood of a multivariate normal distribution in precision form. May be phased out - see https://github.com/pymc-devs/pymc/pull/7895 - - Parameters - ---------- - value: TensorLike - Query point to compute the log prob at. - mean: TensorLike - Mean vector of the Gaussian, - tau: TensorLike - Precision matrix of the Gaussian (i.e. cov = inv(tau)) - - Returns - ------- - logp: TensorLike - Log likelihood at value. - posdef: TensorLike - Boolean indicating whether the precision matrix is positive definite. - """ - k = value.shape[-1].astype("floatX") - - delta = value - mean - quadratic_form = delta.T @ tau @ delta - logdet, posdef = _logdet_from_cholesky(pt.linalg.cholesky(tau, lower=True)) - logp = -0.5 * (k * pt.log(2 * np.pi) + quadratic_form) + logdet +@marginalized_conditional.register(MarginalFiniteDiscreteRV) +def build_finite_discrete_marginalized_conditional(op, node): + fgraph = op.fgraph.clone() + marginalized = fgraph.outputs[0] + dependents = list(fgraph.outputs[1 : 1 + op.n_dependent_rvs]) - return logp, posdef + marginalized_value = marginalized.clone() + dep_dummies = [dep.type() for dep in dependents] + rvs_to_values = {marginalized: marginalized_value} + for inner_dep, dummy in zip(dependents, dep_dummies): + rvs_to_values[inner_dep] = dummy + logps_dict = conditional_logp(rvs_to_values) + marginalized_logp = logps_dict[marginalized_value] + dependent_logps = [logps_dict[dummy] for dummy in dep_dummies] -def get_laplace_approx( - log_likelihood: TensorVariable, - logp_objective: TensorVariable, - x: TensorVariable, - x0_init: TensorLike, - Q: TensorLike, - minimizer_kwargs: dict = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}, -): - """ - Compute the laplace approximation logp_G(x | y, params) of some variable x. - - Parameters - ---------- - log_likelihood: TensorVariable - Model likelihood logp(y | x, params). - logp_objective: TensorVariable - Obective log likelihood to maximize, logp(x | y, params) (up to some constant in x). - x: TensorVariable - Variable to be laplace approximated. - x0_init: TensorLike - Initial guess for minimization. - Q: TensorLike - Precision matrix of x. - minimizer_kwargs: - Kwargs to pass to pytensor.optimize.minimize. - - Returns - ------- - x0: TensorVariable - x*, the maximizer of logp(x | y, params) in x. - log_laplace_approx: TensorVariable - Laplace approximation of logp(x | y, params) evaluated at x. - """ - # Maximize log(p(x | y, params)) wrt x to find mode x0 - # This step is currently bottlenecking the logp calculation. - x0, _ = minimize( - objective=-logp_objective, # logp(x | y, params) = logp(y | x, params) + logp(x | params) + const (const omitted during minimization) - x=x, - use_vectorized_jac=True, - **minimizer_kwargs, + joint_logp = marginalized_logp + reduce_batch_dependent_logps( + op.dims_connections, + [dep.owner.op for dep in dependents], + dependent_logps, ) - # Set minimizer initialisation to be random - x0 = pytensor.graph.replace.graph_replace(x0, {x: x0_init}) - - # This step is also expensive (but not as much as minimize). Could be made more efficient by recycling hessian from the minimizer step, however that requires a bespoke algorithm described in Rasmussen & Williams - # since the general optimisation scheme maximises logp(x | y, params) rather than logp(y | x, params), and thus the hessian that comes out of methods - # like L-BFGS-B is in fact not the hessian of logp(y | x, params) - # TODO: Use vectorized hessian? - hess = pytensor.gradient.hessian(log_likelihood, x) - - # Evaluate logp of Laplace approx of logp(x | y, params) at some point x - tau = Q - hess - mu = x0 - log_laplace_approx, _ = _precision_mv_normal_logp(x, mu, tau) - - return x0, log_laplace_approx - - -@_logprob.register(MarginalLaplaceRV) -def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs_and_Q, **kwargs): - # Get Q and remove it from the graph (stored as a dummy input) - *inputs, Q = inputs_and_Q - - # Clone the inner RV graph of the Marginalized RV - x, *inner_rvs = inline_ofg_outputs(op, inputs) - - # Obtain the joint_logp graph of the inner RV graph - inner_rv_values = dict(zip(inner_rvs, values)) - - marginalized_vv = x.clone() - rv_values = inner_rv_values | {x: marginalized_vv} - logps_dict = conditional_logp(rv_values=rv_values, **kwargs) - - # logp(x | params) - logp_x = logps_dict.pop(marginalized_vv).sum() - - # logp(y | x, params) - logp_y = pt.sum([logp_term.sum() for value, logp_term in logps_dict.items()]) - - # logp_total = logp(y | x, params) + logp(x | params) (i.e. logp(x | y, params) up to a constant in x) - logp_total = logp_x + logp_y - - # Set minimizer initialisation to be random (TODO: Let pymc accept this one, maybe when rng is constant) - # TODO: Use newer pytensor helper - d = pt.prod(constant_fold(tuple(x.shape), raise_not_constant=True)) - x0_init = pt.ones(d) + rv_shape = constant_fold(tuple(marginalized.shape), raise_not_constant=False) + rv_domain = get_domain_of_finite_discrete_rv(marginalized) + rv_domain_tensor = pt.moveaxis( + pt.full( + (*rv_shape, len(rv_domain)), + rv_domain, + dtype=marginalized.dtype, + ), + -1, + 0, + ) - # Obtain laplace approx for logp(x | y, params) - x0, log_laplace_approx = get_laplace_approx( - logp_y, - logp_total, - x=marginalized_vv, - x0_init=x0_init, - Q=Q, - minimizer_kwargs=op.minimizer_kwargs, + batched_joint_logp = vectorize_graph( + joint_logp, + replace={marginalized_value: rv_domain_tensor}, ) + batched_joint_logp = pt.moveaxis(batched_joint_logp, 0, -1) + + sample_graph = Categorical.dist(logit_p=batched_joint_logp) + if isinstance(marginalized.owner.op, DiscreteUniform): + # rv_domain[0] is folded to a float; adding it directly would insert a + # Cast{float64} that breaks logp derivation. Keep the offset integral and + # matching the marginalized dtype so the conditional stays loggable. + sample_graph += rv_domain[0].astype(marginalized.dtype) - # logp(y | params) = logp(y | x, params) + logp(x | params) - logp(x | y, params) - # TODO: Can we recover the elementwise logp? - marginal_likelihood = logp_total - log_laplace_approx - return graph_replace(marginal_likelihood, {marginalized_vv: x0}) + return sample_graph, dep_dummies diff --git a/pymc_extras/model/marginal/distributions/laplace.py b/pymc_extras/model/marginal/distributions/laplace.py new file mode 100644 index 000000000..3c3c1cb9e --- /dev/null +++ b/pymc_extras/model/marginal/distributions/laplace.py @@ -0,0 +1,177 @@ +import numpy as np +import pytensor +import pytensor.tensor as pt + +from pymc.distributions.multivariate import _logdet_from_cholesky +from pymc.logprob.abstract import _logprob +from pymc.logprob.basic import conditional_logp +from pymc.pytensorf import constant_fold +from pytensor.graph.replace import graph_replace +from pytensor.tensor import TensorLike, TensorVariable +from pytensor.tensor.optimize import minimize + +from pymc_extras.model.marginal.distributions.core import ( + MarginalRV, + inline_ofg_outputs, +) + + +class MarginalLaplaceRV(MarginalRV): + """Base class for Marginalized Laplace-Approximated RVs. + + Estimates log likelihood using Laplace approximations. + + The precision matrix Q of the marginalized variable is passed as the + last input of the node (a dummy input, unused by the inner graph). + """ + + def __init__( + self, + *args, + marginalized_dims, + n_dependent_rvs: int, + minimizer_seed: int, + minimizer_kwargs: dict = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}, + **kwargs, + ) -> None: + self.marginalized_dims = marginalized_dims + self.n_dependent_rvs = n_dependent_rvs + self.minimizer_seed = minimizer_seed + self.minimizer_kwargs = minimizer_kwargs + super().__init__(*args, **kwargs) + + +def _precision_mv_normal_logp(value: TensorLike, mean: TensorLike, tau: TensorLike): + """ + Compute the log likelihood of a multivariate normal distribution in precision form. May be phased out - see https://github.com/pymc-devs/pymc/pull/7895 + + Parameters + ---------- + value: TensorLike + Query point to compute the log prob at. + mean: TensorLike + Mean vector of the Gaussian, + tau: TensorLike + Precision matrix of the Gaussian (i.e. cov = inv(tau)) + + Returns + ------- + logp: TensorLike + Log likelihood at value. + posdef: TensorLike + Boolean indicating whether the precision matrix is positive definite. + """ + k = value.shape[-1].astype("floatX") + + delta = value - mean + quadratic_form = delta.T @ tau @ delta + logdet, posdef = _logdet_from_cholesky(pt.linalg.cholesky(tau, lower=True)) + logp = -0.5 * (k * pt.log(2 * np.pi) + quadratic_form) + logdet + + return logp, posdef + + +def get_laplace_approx( + log_likelihood: TensorVariable, + logp_objective: TensorVariable, + x: TensorVariable, + x0_init: TensorLike, + Q: TensorLike, + minimizer_kwargs: dict = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}, +): + """ + Compute the laplace approximation logp_G(x | y, params) of some variable x. + + Parameters + ---------- + log_likelihood: TensorVariable + Model likelihood logp(y | x, params). + logp_objective: TensorVariable + Obective log likelihood to maximize, logp(x | y, params) (up to some constant in x). + x: TensorVariable + Variable to be laplace approximated. + x0_init: TensorLike + Initial guess for minimization. + Q: TensorLike + Precision matrix of x. + minimizer_kwargs: + Kwargs to pass to pytensor.optimize.minimize. + + Returns + ------- + x0: TensorVariable + x*, the maximizer of logp(x | y, params) in x. + log_laplace_approx: TensorVariable + Laplace approximation of logp(x | y, params) evaluated at x. + """ + # Maximize log(p(x | y, params)) wrt x to find mode x0 + # This step is currently bottlenecking the logp calculation. + x0, _ = minimize( + objective=-logp_objective, # logp(x | y, params) = logp(y | x, params) + logp(x | params) + const (const omitted during minimization) + x=x, + use_vectorized_jac=True, + **minimizer_kwargs, + ) + + # Set minimizer initialisation to be random + x0 = pytensor.graph.replace.graph_replace(x0, {x: x0_init}) + + # This step is also expensive (but not as much as minimize). Could be made more efficient by recycling hessian from the minimizer step, however that requires a bespoke algorithm described in Rasmussen & Williams + # since the general optimisation scheme maximises logp(x | y, params) rather than logp(y | x, params), and thus the hessian that comes out of methods + # like L-BFGS-B is in fact not the hessian of logp(y | x, params) + # TODO: Use vectorized hessian? + hess = pytensor.gradient.hessian(log_likelihood, x) + + # Evaluate logp of Laplace approx of logp(x | y, params) at some point x + tau = Q - hess + mu = x0 + log_laplace_approx, _ = _precision_mv_normal_logp(x, mu, tau) + + return x0, log_laplace_approx + + +@_logprob.register(MarginalLaplaceRV) +def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs_and_Q, **kwargs): + # Get Q and remove it from the graph (stored as a dummy input) + *inputs, Q = inputs_and_Q + + # Clone the inner RV graph of the Marginalized RV + all_outputs = inline_ofg_outputs(op, inputs_and_Q) + x = all_outputs[0] + inner_rvs = list(all_outputs[1 : 1 + op.n_dependent_rvs]) + + # Obtain the joint_logp graph of the inner RV graph + inner_rv_values = dict(zip(inner_rvs, values)) + + marginalized_vv = x.clone() + rv_values = inner_rv_values | {x: marginalized_vv} + logps_dict = conditional_logp(rv_values=rv_values, **kwargs) + + # logp(x | params) + logp_x = logps_dict.pop(marginalized_vv).sum() + + # logp(y | x, params) + logp_y = pt.sum([logp_term.sum() for value, logp_term in logps_dict.items()]) + + # logp_total = logp(y | x, params) + logp(x | params) (i.e. logp(x | y, params) up to a constant in x) + logp_total = logp_x + logp_y + + # Set minimizer initialisation to be random (TODO: Let pymc accept this one, maybe when rng is constant) + # TODO: Use newer pytensor helper + d = pt.prod(constant_fold(tuple(x.shape), raise_not_constant=True)) + x0_init = pt.ones(d) + + # Obtain laplace approx for logp(x | y, params) + x0, log_laplace_approx = get_laplace_approx( + logp_y, + logp_total, + x=marginalized_vv, + x0_init=x0_init, + Q=Q, + minimizer_kwargs=op.minimizer_kwargs, + ) + + # logp(y | params) = logp(y | x, params) + logp(x | params) - logp(x | y, params) + # TODO: Can we recover the elementwise logp? + marginal_likelihood = logp_total - log_laplace_approx + return graph_replace(marginal_likelihood, {marginalized_vv: x0}) diff --git a/pymc_extras/model/marginal/graph_analysis.py b/pymc_extras/model/marginal/graph_analysis.py index f4d79f799..7c48fe0e1 100644 --- a/pymc_extras/model/marginal/graph_analysis.py +++ b/pymc_extras/model/marginal/graph_analysis.py @@ -16,7 +16,7 @@ from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list from pytensor.tensor.type_other import NoneTypeT -from pymc_extras.model.marginal.distributions import MarginalRV +from pymc_extras.model.marginal.distributions.core import MarginalRV def static_shape_ancestors(vars): @@ -34,7 +34,7 @@ def static_shape_ancestors(vars): def find_conditional_input_rvs(output_rvs, all_rvs): - """Find conditionally indepedent input RVs.""" + """Find conditionally independent input RVs.""" other_rvs = [other_rv for other_rv in all_rvs if other_rv not in output_rvs] blockers = other_rvs + static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs)) return [var for var in ancestors(output_rvs, blockers=blockers) if var in other_rvs] @@ -43,9 +43,17 @@ def find_conditional_input_rvs(output_rvs, all_rvs): def is_conditional_dependent( dependent_rv: TensorVariable, dependable_rv: TensorVariable, all_rvs ) -> bool: - """Check if dependent_rv is conditionall dependent on dependable_rv, + """Check if dependent_rv is conditionally dependent on dependable_rv, given all conditionally independent all_rvs""" + # Sibling outputs of the same node are conditionally dependent + if ( + dependent_rv is not dependable_rv + and dependent_rv.owner.inputs[0].owner is not None + and dependent_rv.owner.inputs[0].owner is dependable_rv.owner.inputs[0].owner + ): + return True + return dependable_rv in find_conditional_input_rvs((dependent_rv,), all_rvs) @@ -59,17 +67,12 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs): def get_support_axes(op) -> tuple[tuple[int, ...], ...]: - if isinstance(op, MarginalRV): + if hasattr(op, "support_axes"): return op.support_axes else: - # For vanilla RVs, the support axes are the last ndim_supp return (tuple(range(-op.ndim_supp, 0)),) -def _is_tensor_idx(idx) -> bool: - return isinstance(idx, Variable) and isinstance(idx.type, TensorType) - - def _advanced_indexing_axis_and_ndim(idxs) -> tuple[int, int]: """Find the output axis and dimensionality of the advanced indexing group (i.e., array indexing). @@ -78,6 +81,10 @@ def _advanced_indexing_axis_and_ndim(idxs) -> tuple[int, int]: See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing """ + + def _is_tensor_idx(idx) -> bool: + return isinstance(idx, Variable) and isinstance(idx.type, TensorType) + adv_group_axis = None simple_group_after_adv = False for axis, idx in enumerate(idxs): diff --git a/pymc_extras/model/marginal/marginal_model.py b/pymc_extras/model/marginal/marginal_model.py deleted file mode 100644 index ab51c01ba..000000000 --- a/pymc_extras/model/marginal/marginal_model.py +++ /dev/null @@ -1,647 +0,0 @@ -import warnings - -from collections.abc import Sequence - -import numpy as np -import pymc -import pytensor.tensor as pt - -from arviz_base import dict_to_dataset -from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list -from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform -from pymc.distributions.transforms import Chain -from pymc.logprob.transforms import IntervalTransform -from pymc.model import Model, modelcontext -from pymc.model.fgraph import ( - ModelFreeRV, - ModelValuedVar, - fgraph_from_model, - model_free_rv, - model_from_fgraph, -) -from pymc.pytensorf import collect_default_updates, constant_fold, toposort_replace -from pymc.pytensorf import compile as compile_pymc -from pymc.util import RandomState, _get_seeds_per_chain -from pytensor.compile import SharedVariable -from pytensor.compile.io import In, Out -from pytensor.graph import ( - FunctionGraph, - Variable, - clone_replace, - graph_inputs, - graph_replace, - node_rewriter, - vectorize_graph, -) -from pytensor.graph.rewriting.basic import in2out -from pytensor.tensor import TensorVariable -from xarray import DataTree - -__all__ = ["MarginalModel", "marginalize"] - -from pytensor.tensor.random.type import RandomType -from pytensor.tensor.special import log_softmax - -from pymc_extras.distributions import DiscreteMarkovChain -from pymc_extras.model.marginal.distributions import ( - MarginalDiscreteMarkovChainRV, - MarginalFiniteDiscreteRV, - MarginalLaplaceRV, - MarginalRV, - NonSeparableLogpWarning, - get_domain_of_finite_discrete_rv, - inline_ofg_outputs, - reduce_batch_dependent_logps, -) -from pymc_extras.model.marginal.graph_analysis import ( - find_conditional_dependent_rvs, - find_conditional_input_rvs, - is_conditional_dependent, - subgraph_batch_dim_connection, -) - -ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str] - - -class MarginalModel(Model): - """Subclass of PyMC Model that implements functionality for automatic - marginalization of variables in the logp transformation - - After defining the full Model, the `marginalize` method can be used to indicate a - subset of variables that should be marginalized - - Notes - ----- - Marginalization functionality is still very restricted. Only finite discrete - variables can be marginalized. Deterministics and Potentials cannot be conditionally - dependent on the marginalized variables. - - Furthermore, not all instances of such variables can be marginalized. If a variable - has batched dimensions, it is required that any conditionally dependent variables - use information from an individual batched dimension. In other words, the graph - connecting the marginalized variable(s) to the dependent variable(s) must be - composed strictly of Elemwise Operations. This is necessary to ensure an efficient - logprob graph can be generated. If you want to bypass this restriction you can - separate each dimension of the marginalized variable into the scalar components - and then stack them together. Note that such graphs will grow exponentially in the - number of marginalized variables. - - For the same reason, it's not possible to marginalize RVs with multivariate - dependent RVs. - - Examples - -------- - Marginalize over a single variable - - .. code-block:: python - - import pymc as pm - from pymc_extras import MarginalModel - - with MarginalModel() as m: - p = pm.Beta("p", 1, 1) - x = pm.Bernoulli("x", p=p, shape=(3,)) - y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10]) - - m.marginalize([x]) - - idata = pm.sample() - - """ - - def __init__(self, *args, **kwargs): - raise TypeError( - "MarginalModel was deprecated in favor of `marginalize` which now returns a PyMC model" - ) - - -def _warn_interval_transform(rv_to_marginalize, replaced_vars: Sequence[ModelValuedVar]) -> None: - for replaced_var in replaced_vars: - if not isinstance(replaced_var.owner.op, ModelValuedVar): - raise TypeError(f"{replaced_var} is not a ModelValuedVar") - - if not isinstance(replaced_var.owner.op, ModelFreeRV): - continue - - if replaced_var is rv_to_marginalize: - continue - - transform = replaced_var.owner.op.transform - - if isinstance(transform, IntervalTransform) or ( - isinstance(transform, Chain) - and any(isinstance(tr, IntervalTransform) for tr in transform.transform_list) - ): - warnings.warn( - f"The transform {transform} for the variable {replaced_var}, which depends on the " - f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.", - UserWarning, - ) - - -def _unique(seq: Sequence) -> list: - """Copied from https://stackoverflow.com/a/480227""" - seen = set() - seen_add = seen.add - return [x for x in seq if not (x in seen or seen_add(x))] - - -def marginalize( - model: Model, rvs_to_marginalize: ModelRVs, use_laplace: bool = False, **marginalize_kwargs -) -> MarginalModel: - """Marginalize a subset of variables in a PyMC model. - - This creates a class of `MarginalModel` from an existing `Model`, with the specified - variables marginalized. - - See documentation for `MarginalModel` for more information. - - Parameters - ---------- - model : Model - PyMC model to marginalize. Original variables well be cloned. - rvs_to_marginalize : Sequence[TensorVariable] - Variables to marginalize in the returned model. - use_laplace : bool - Whether to use Laplace appoximations to marginalize out rvs_to_marginalize. - - Returns - ------- - marginal_model: MarginalModel - Marginal model with the specified variables marginalized. - """ - if isinstance(rvs_to_marginalize, str | Variable): - rvs_to_marginalize = (rvs_to_marginalize,) - - rvs_to_marginalize = [model[rv] if isinstance(rv, str) else rv for rv in rvs_to_marginalize] - - if not rvs_to_marginalize: - return model - - for rv_to_marginalize in rvs_to_marginalize: - if rv_to_marginalize not in model.free_RVs: - raise ValueError(f"Marginalized RV {rv_to_marginalize} is not a free RV in the model") - - rv_op = rv_to_marginalize.owner.op - if isinstance(rv_op, DiscreteMarkovChain): - if rv_op.n_lags > 1: - raise NotImplementedError( - "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported" - ) - if rv_to_marginalize.owner.inputs[0].type.ndim > 2: - raise NotImplementedError( - "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported" - ) - - elif not use_laplace and not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform): - raise NotImplementedError( - f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported" - ) - - fg, memo = fgraph_from_model(model) - rvs_to_marginalize = [memo[rv] for rv in rvs_to_marginalize] - toposort = fg.toposort() - - for rv_to_marginalize in sorted( - rvs_to_marginalize, - key=lambda rv: toposort.index(rv.owner), - reverse=True, - ): - all_rvs = [node.out for node in fg.toposort() if isinstance(node.op, ModelValuedVar)] - - dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs) - if not dependent_rvs: - # TODO: This should at most be a warning, not an error - raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}") - - # Issue warning for IntervalTransform on dependent RVs - for dependent_rv in dependent_rvs: - transform = dependent_rv.owner.op.transform - - if isinstance(transform, IntervalTransform) or ( - isinstance(transform, Chain) - and any(isinstance(tr, IntervalTransform) for tr in transform.transform_list) - ): - warnings.warn( - f"The transform {transform} for the variable {dependent_rv}, which depends on the " - f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.", - UserWarning, - ) - - # Check that no deterministics or potentials depend on the rv to marginalize - for det in model.deterministics: - if is_conditional_dependent(memo[det], rv_to_marginalize, all_rvs): - raise NotImplementedError( - f"Cannot marginalize {rv_to_marginalize} due to dependent Deterministic {det}" - ) - for pot in model.potentials: - if is_conditional_dependent(memo[pot], rv_to_marginalize, all_rvs): - raise NotImplementedError( - f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}" - ) - - marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs) - other_direct_rv_ancestors = [ - rv - for rv in find_conditional_input_rvs(dependent_rvs, all_rvs) - if rv is not rv_to_marginalize - ] - input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors)) - - if use_laplace: - Q = marginalize_kwargs["Q"] - marginalize_kwargs["Q"] = memo.get(Q, pt.as_tensor_variable(Q)).copy() - - replace_marginal_subgraph( - fg, rv_to_marginalize, dependent_rvs, input_rvs, use_laplace, **marginalize_kwargs - ) - - return model_from_fgraph(fg, mutate_fgraph=True) - - -@node_rewriter(tracks=[MarginalRV]) -def local_unmarginalize(fgraph, node): - unmarginalized_rv, *dependent_rvs_and_rngs = inline_ofg_outputs(node.op, node.inputs) - rngs = [rng for rng in dependent_rvs_and_rngs if isinstance(rng.type, RandomType)] - dependent_rvs = [rv for rv in dependent_rvs_and_rngs if rv not in rngs] - - # Wrap the marginalized RV in a FreeRV - # TODO: Preserve dims and transform in MarginalRV - value = unmarginalized_rv.clone() - fgraph.add_input(value) - transform = None - unmarginalized_free_rv = model_free_rv(unmarginalized_rv, value, transform, *node.op.dims) - - # Replace references to the marginalized RV with the FreeRV in the dependent RVs - dependent_rvs = graph_replace(dependent_rvs, {unmarginalized_rv: unmarginalized_free_rv}) - - return [unmarginalized_free_rv, *dependent_rvs, *rngs] - - -unmarginalize_rewrite = in2out(local_unmarginalize, ignore_newtrees=False) - - -def unmarginalize(model: Model, rvs_to_unmarginalize: str | Sequence[str] | None = None) -> Model: - """Unmarginalize a subset of variables in a PyMC model. - - - Parameters - ---------- - model : Model - PyMC model to unmarginalize. Original variables well be cloned. - rvs_to_unmarginalize : str or sequence of str, optional - Variables to unmarginalize in the returned model. If None, all variables are - unmarginalized. - - Returns - ------- - unmarginal_model: Model - Model with the specified variables unmarginalized. - """ - - # Unmarginalize all the MarginalRVs - fg, memo = fgraph_from_model(model) - unmarginalize_rewrite(fg) - unmarginalized_model = model_from_fgraph(fg, mutate_fgraph=True) - if rvs_to_unmarginalize is None: - return unmarginalized_model - - # Re-marginalize the variables we want to keep marginalized - if not isinstance(rvs_to_unmarginalize, list | tuple): - rvs_to_unmarginalize = (rvs_to_unmarginalize,) - rvs_to_unmarginalize = set(rvs_to_unmarginalize) - - old_free_rv_names = set(rv.name for rv in model.free_RVs) - new_free_rv_names = set( - rv.name for rv in unmarginalized_model.free_RVs if rv.name not in old_free_rv_names - ) - if rvs_to_unmarginalize - new_free_rv_names: - raise ValueError( - f"Unrecognized rvs_to_unmarginalize: {rvs_to_unmarginalize - new_free_rv_names}" - ) - rvs_to_keep_marginalized = tuple(new_free_rv_names - rvs_to_unmarginalize) - return marginalize(unmarginalized_model, rvs_to_keep_marginalized) - - -def transform_posterior_pts(model, posterior_pts): - """Create a function from the untransformed space to the transformed space""" - # TODO: This should be a utility in PyMC - transformed_rvs = [] - transformed_names = [] - - for rv in model.free_RVs: - transform = model.rvs_to_transforms.get(rv) - if transform is None: - transformed_rvs.append(rv) - transformed_names.append(rv.name) - else: - transformed_rv = transform.forward(rv, *rv.owner.inputs) - transformed_rvs.append(transformed_rv) - transformed_names.append(model.rvs_to_values[rv].name) - - fn = compile_pymc( - inputs=[In(inp, borrow=True) for inp in model.free_RVs], - outputs=[Out(out, borrow=True) for out in transformed_rvs], - ) - fn.trust_input = True - - # TODO: This should work with vectorized inputs - return [dict(zip(transformed_names, fn(**point))) for point in posterior_pts] - - -def recover_marginals( - idata: DataTree, - *, - model: Model | None = None, - var_names: Sequence[str] | None = None, - return_samples: bool = True, - extend_inferencedata: bool = True, - random_seed: RandomState = None, -): - """Computes posterior log-probabilities and samples of marginalized variables - conditioned on parameters of the model given DataTree with posterior group - - When there are multiple marginalized variables, each marginalized variable is - conditioned on both the parameters and the other variables still marginalized - - All log-probabilities are within the transformed space - - Parameters - ---------- - model: Model - PyMC model with marginalized variables to recover - idata : DataTree - DataTree with posterior group - var_names : sequence of str, optional - List of variable names for which to compute posterior log-probabilities and samples. Defaults to all marginalized variables - return_samples : bool, default True - If True, also return samples of the marginalized variables - extend_inferencedata : bool, default True - Whether to extend the original DataTree or return a new one - random_seed: int, array-like of int or SeedSequence, optional - Seed used to generating samples - - Returns - ------- - idata : DataTree - DataTree with where a lp_{varname} and {varname} for each marginalized variable in var_names added to the posterior group - - .. code-block:: python - - import pymc as pm - from pymc_extras import MarginalModel - - with MarginalModel() as m: - p = pm.Beta("p", 1, 1) - x = pm.Bernoulli("x", p=p, shape=(3,)) - y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10]) - - m.marginalize([x]) - - idata = pm.sample() - m.recover_marginals(idata, var_names=["x"]) - - - """ - # Temporary error message for helping with migration - # Will be removed in a future release - if isinstance(idata, Model): - raise TypeError( - "The order of arguments of `recover_marginals` changed. The first input must be an idata" - ) - - model = modelcontext(model) - - unmarginal_model = unmarginalize(model) - - # Find the names of the marginalized variables - model_var_names = set(rv.name for rv in model.free_RVs) - marginalized_rv_names = [ - rv.name for rv in unmarginal_model.free_RVs if rv.name not in model_var_names - ] - - if var_names is None: - var_names = marginalized_rv_names - - var_names = [var if isinstance(var, str) else var.name for var in var_names] - var_names_to_recover = [name for name in marginalized_rv_names if name in var_names] - missing_names = [name for name in var_names_to_recover if name not in marginalized_rv_names] - if missing_names: - raise ValueError(f"Unrecognized var_names: {missing_names}") - - if return_samples and random_seed is not None: - seeds = _get_seeds_per_chain(random_seed, len(var_names_to_recover)) - else: - seeds = [None] * len(var_names_to_recover) - - posterior_pts, stacked_dims = dataset_to_point_list( - # Remove Deterministics - idata["posterior"].dataset[[rv.name for rv in model.free_RVs]], - sample_dims=("chain", "draw"), - ) - transformed_posterior_pts = transform_posterior_pts(model, posterior_pts) - - rv_dict = {} - rv_dims = {} - for seed, var_name_to_recover in zip(seeds, var_names_to_recover): - var_to_recover = unmarginal_model[var_name_to_recover] - supported_dists = (Bernoulli, Categorical, DiscreteUniform) - if not isinstance(var_to_recover.owner.op, supported_dists): - raise NotImplementedError( - f"RV with distribution {var_to_recover.owner.op} cannot be recovered. " - f"Supported distribution include {supported_dists}" - ) - - other_marginalized_rvs_names = marginalized_rv_names.copy() - other_marginalized_rvs_names.remove(var_name_to_recover) - dependent_rvs = [ - rv - for rv in find_conditional_dependent_rvs(var_to_recover, unmarginal_model.basic_RVs) - if rv.name not in other_marginalized_rvs_names - ] - # Handle batch dims for marginalized value and its dependent RVs - dependent_rvs_dim_connections = subgraph_batch_dim_connection(var_to_recover, dependent_rvs) - - marginalized_model = marginalize(unmarginal_model, other_marginalized_rvs_names) - - marginalized_var_to_recover = marginalized_model[var_name_to_recover] - dependent_rvs = [marginalized_model[rv.name] for rv in dependent_rvs] - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=NonSeparableLogpWarning) - logps = marginalized_model.logp( - vars=[marginalized_var_to_recover, *dependent_rvs], sum=False - ) - - marginalized_logp, *dependent_logps = logps - joint_logp = marginalized_logp + reduce_batch_dependent_logps( - dependent_rvs_dim_connections, - [dependent_var.owner.op for dependent_var in dependent_rvs], - dependent_logps, - ) - - marginalized_value = marginalized_model.rvs_to_values[marginalized_var_to_recover] - other_values = [v for v in marginalized_model.value_vars if v is not marginalized_value] - - rv_shape = constant_fold(tuple(var_to_recover.shape), raise_not_constant=False) - rv_domain = get_domain_of_finite_discrete_rv(var_to_recover) - rv_domain_tensor = pt.moveaxis( - pt.full( - (*rv_shape, len(rv_domain)), - rv_domain, - dtype=var_to_recover.dtype, - ), - -1, - 0, - ) - - batched_joint_logp = vectorize_graph( - joint_logp, - replace={marginalized_value: rv_domain_tensor}, - ) - batched_joint_logp = pt.moveaxis(batched_joint_logp, 0, -1) - - joint_logp_norm = log_softmax(batched_joint_logp, axis=-1) - if return_samples: - rv_draws = Categorical.dist(logit_p=batched_joint_logp) - if isinstance(var_to_recover.owner.op, DiscreteUniform): - rv_draws += rv_domain[0] - outputs = [joint_logp_norm, rv_draws] - else: - outputs = joint_logp_norm - - rv_loglike_fn = compile_pymc( - inputs=other_values, - outputs=outputs, - on_unused_input="ignore", - random_seed=seed, - ) - - logvs = [rv_loglike_fn(**vs) for vs in transformed_posterior_pts] - - if return_samples: - logps, samples = zip(*logvs) - logps = np.asarray(logps) - samples = np.asarray(samples) - rv_dict[var_name_to_recover] = samples.reshape( - tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:], - ) - else: - logps = np.asarray(logvs) - - rv_dict["lp_" + var_name_to_recover] = logps.reshape( - tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:], - ) - if var_name_to_recover in unmarginal_model.named_vars_to_dims: - rv_dims[var_name_to_recover] = list( - unmarginal_model.named_vars_to_dims[var_name_to_recover] - ) - rv_dims["lp_" + var_name_to_recover] = rv_dims[var_name_to_recover] + [ - "lp_" + var_name_to_recover + "_dim" - ] - - coords, dims = coords_and_dims_for_inferencedata(unmarginal_model) - dims.update(rv_dims) - rv_dataset = dict_to_dataset( - rv_dict, - inference_library=pymc, - dims=dims, - coords=coords, - skip_event_dims=True, - ) - - if extend_inferencedata: - idata["posterior"] = idata["posterior"].assign(rv_dataset) - return idata - else: - return rv_dataset - - -def collect_shared_vars(outputs, blockers): - return [ - inp - for inp in graph_inputs(outputs, blockers=blockers) - if (isinstance(inp, SharedVariable) and inp not in blockers) - ] - - -def remove_model_vars(vars): - """Remove ModelVars from the graph of vars.""" - model_vars = [var for var in vars if isinstance(var.owner.op, ModelValuedVar)] - replacements = [(model_var, model_var.owner.inputs[0]) for model_var in model_vars] - fgraph = FunctionGraph(outputs=vars, clone=False) - toposort_replace(fgraph, replacements) - return fgraph.outputs - - -def replace_marginal_subgraph( - fgraph, - rv_to_marginalize, - dependent_rvs, - input_rvs, - use_laplace=False, - **marginalize_kwargs, -) -> None: - # If the marginalized RV has multiple dimensions, check that graph between - # marginalized RV and dependent RVs does not mix information from batch dimensions - # (otherwise logp would require enumerating over all combinations of batch dimension values) - if not use_laplace: - try: - dependent_rvs_dim_connections = subgraph_batch_dim_connection( - rv_to_marginalize, dependent_rvs - ) - except (ValueError, NotImplementedError) as e: - # For the perspective of the user this is a NotImplementedError - raise NotImplementedError( - "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. " - "You can try splitting the marginalized RV into separate components and marginalizing them separately." - ) from e - else: - dependent_rvs_dim_connections = None - - output_rvs = [rv_to_marginalize, *dependent_rvs] - rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False) - outputs = output_rvs + list(rng_updates.values()) - inputs = input_rvs + list(rng_updates.keys()) - # Add any other shared variable inputs - inputs += collect_shared_vars(output_rvs, blockers=inputs) - - if use_laplace: - Q = marginalize_kwargs.pop("Q") - inputs.append(Q) - - inner_inputs = [inp.clone() for inp in inputs] - inner_outputs = clone_replace(outputs, replace=dict(zip(inputs, inner_inputs))) - inner_outputs = remove_model_vars(inner_outputs) - - if isinstance(inner_outputs[0].owner.op, DiscreteMarkovChain): - marginalize_constructor = MarginalDiscreteMarkovChainRV - elif use_laplace: - marginalize_constructor = MarginalLaplaceRV - else: - marginalize_constructor = MarginalFiniteDiscreteRV - - _, _, *dims = rv_to_marginalize.owner.inputs - marginalization_op = marginalize_constructor( - inputs=inner_inputs, - outputs=inner_outputs, - dims_connections=dependent_rvs_dim_connections, - dims=dims, - **marginalize_kwargs, - ) - - new_outputs = marginalization_op(*inputs) - for old_output, new_output in zip(outputs, new_outputs): - new_output.name = old_output.name - - model_replacements = [] - for old_output, new_output in zip(outputs, new_outputs): - if old_output is rv_to_marginalize or not isinstance(old_output.owner.op, ModelValuedVar): - # Replace the marginalized ModelFreeRV (or non model-variables) themselves - var_to_replace = old_output - else: - # Replace the underlying RV, keeping the same value, transform and dims - var_to_replace = old_output.owner.inputs[0] - model_replacements.append((var_to_replace, new_output)) - - fgraph.replace_all(model_replacements) diff --git a/pymc_extras/model/marginal/model.py b/pymc_extras/model/marginal/model.py new file mode 100644 index 000000000..3412d86a4 --- /dev/null +++ b/pymc_extras/model/marginal/model.py @@ -0,0 +1,774 @@ +import warnings + +from collections.abc import Sequence + +import pytensor.tensor as pt + +from pymc.distributions.transforms import Chain +from pymc.logprob.transforms import IntervalTransform +from pymc.model import Model, modelcontext +from pymc.model.fgraph import ( + ModelValuedVar, + extract_dims, + fgraph_from_model, + model_from_fgraph, +) +from pymc.util import RandomState, _get_seeds_per_chain +from pytensor.compile import SharedVariable +from pytensor.graph import ( + Variable, + graph_inputs, +) +from pytensor.graph.replace import graph_replace +from pytensor.graph.rewriting.basic import in2out +from pytensor.graph.rewriting.db import RewriteDatabaseQuery +from pytensor.graph.traversal import io_toposort +from pytensor.tensor import TensorVariable +from xarray import DataTree, merge + +from pymc_extras.model.marginal.distributions.core import ( + MarginalRV, + marginalized_conditional, +) +from pymc_extras.model.marginal.distributions.laplace import MarginalLaplaceRV +from pymc_extras.model.marginal.graph_analysis import ( + find_conditional_dependent_rvs, + find_conditional_input_rvs, + is_conditional_dependent, +) +from pymc_extras.model.marginal.rewrites import ( + DeferredMarginalSubgraph, + LaplaceMarginalSubgraph, + MarginalSubgraph, + MarginalSubgraphBase, + local_unmarginalize, + marginal_rewrites_db, +) + +ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str] + + +def _unique(seq: Sequence) -> list: + """Copied from https://stackoverflow.com/a/480227""" + seen = set() + seen_add = seen.add + return [x for x in seq if not (x in seen or seen_add(x))] + + +def _get_marginalized_rv_names(model, unmarginal_model): + model_var_names = set(rv.name for rv in model.free_RVs) + return [rv.name for rv in unmarginal_model.free_RVs if rv.name not in model_var_names] + + +def replace_marginal_subgraph( + fgraph, rv_to_marginalize, dependent_rvs, input_rvs, use_laplace=False, **marginalize_kwargs +) -> None: + """Replace a marginalized subgraph with a flat MarginalSubgraph marker Op. + + The subgraph stays alive in the fgraph — the MS node references both + the subgraph outputs and boundary vars as its inputs. No cloning here; + rewrites clone at resolution time when building the OpFromGraph. + + If `use_laplace` is True, a LaplaceMarginalSubgraph marker is created + instead, with the precision matrix Q appended as the last boundary input + and the minimizer options stored on the marker. + """ + raw_marg = rv_to_marginalize.owner.inputs[0] + raw_deps = [ + dep.owner.inputs[0] if isinstance(dep.owner.op, ModelValuedVar) else dep + for dep in dependent_rvs + ] + + subgraph_outputs = [raw_marg, *raw_deps] + boundary = list(input_rvs) + boundary += [ + inp + for inp in graph_inputs(subgraph_outputs, blockers=boundary) + if (isinstance(inp, SharedVariable) and inp not in boundary) + ] + + # Unwrap ModelValuedVar inside the subgraph so the interior only + # references raw RVs. This prevents cycles when rv_to_marginalize + # is replaced by the MS output below. + subgraph_nodes = set(io_toposort(boundary, subgraph_outputs)) + for node in list(subgraph_nodes): + if not isinstance(node.op, ModelValuedVar): + continue + model_var = node.outputs[0] + raw_rv = node.inputs[0] + for client_node, client_idx in list(fgraph.clients.get(model_var, [])): + if client_node in subgraph_nodes: + fgraph.change_node_input(client_node, client_idx, raw_rv, import_missing=True) + + marginalized_dims = extract_dims(rv_to_marginalize) + n_dep = len(dependent_rvs) + + output_types = [out.type for out in subgraph_outputs] + if use_laplace: + # Q goes last so the logp implementation can pop it back + boundary.append(marginalize_kwargs.pop("Q")) + op = LaplaceMarginalSubgraph( + n_dependent_rvs=n_dep, + marginalized_dims=marginalized_dims, + output_types=output_types, + **marginalize_kwargs, + ) + else: + has_nested = any( + rd.owner is not None and isinstance(rd.owner.op, MarginalSubgraphBase) + for rd in raw_deps + ) + cls = DeferredMarginalSubgraph if has_nested else MarginalSubgraph + + op = cls( + n_dependent_rvs=n_dep, + marginalized_dims=marginalized_dims, + output_types=output_types, + ) + + new_outputs = op(*(subgraph_outputs + boundary)) + if not isinstance(new_outputs, list): + new_outputs = list(new_outputs) + + for old, new in zip(subgraph_outputs, new_outputs): + new.name = old.name + + fgraph.replace(rv_to_marginalize, new_outputs[0], import_missing=True) + + for i, dep in enumerate(dependent_rvs): + ms_dep = new_outputs[1 + i] + if isinstance(dep.owner.op, ModelValuedVar): + fgraph.change_node_input(dep.owner, 0, ms_dep, import_missing=True) + + +def marginalize( + model: Model, + rvs_to_marginalize: ModelRVs, + rewrite_query=RewriteDatabaseQuery(include=["basic"]), + use_laplace: bool = False, + **marginalize_kwargs, +) -> Model: + """Marginalize a subset of variables in a PyMC model. + + This creates a new `Model`, with the specified variables marginalized. + + Notes + ----- + Deterministics and Potentials cannot be conditionally dependent on the + marginalized variables. + + Marginalization is resolved via logprob rewrites. The supported cases + include finite discrete variables (Bernoulli, Categorical, + DiscreteUniform, DiscreteMarkovChain) and closed-form conjugate pairs + such as Normal-Normal. + + For finite discrete marginalization with batched dimensions, any + conditionally dependent variables must use information from an individual + batched dimension (i.e., the connecting graph must be strictly Elemwise). + If you want to bypass this restriction you can separate each dimension + of the marginalized variable into scalar components and stack them + together. Note that such graphs will grow exponentially in the number of + marginalized variables. + + Parameters + ---------- + model : Model + PyMC model to marginalize. Original variables will be cloned. + rvs_to_marginalize : Sequence[TensorVariable] + Variables to marginalize in the returned model. + use_laplace : bool + Whether to use Laplace approximations to marginalize out + rvs_to_marginalize. Requires passing the precision matrix ``Q`` of the + marginalized variable via ``marginalize_kwargs``, alongside optional + ``minimizer_seed`` and ``minimizer_kwargs``. + + Returns + ------- + marginal_model: Model + Marginal model with the specified variables marginalized. + + Examples + -------- + .. code-block:: python + + import pymc as pm + from pymc_extras.marginal import marginalize + + with pm.Model() as m: + p = pm.Beta("p", 1, 1) + x = pm.Bernoulli("x", p=p, shape=(3,)) + y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10]) + + marginal_m = marginalize(m, [x]) + idata = pm.sample(model=marginal_m) + """ + if isinstance(rvs_to_marginalize, str | Variable): + rvs_to_marginalize = (rvs_to_marginalize,) + + rvs_to_marginalize = [model[rv] if isinstance(rv, str) else rv for rv in rvs_to_marginalize] + + if not rvs_to_marginalize: + return model + + for rv_to_marginalize in rvs_to_marginalize: + if rv_to_marginalize not in model.free_RVs: + raise ValueError(f"Marginalized RV {rv_to_marginalize} is not a free RV in the model") + + fg, memo = fgraph_from_model(model) + rvs_to_marginalize_fg = [memo[rv] for rv in rvs_to_marginalize] + + rvs_to_marginalize = rvs_to_marginalize_fg + toposort = fg.toposort() + + for rv_to_marginalize in sorted( + rvs_to_marginalize, + key=lambda rv: toposort.index(rv.owner), + reverse=True, + ): + all_rvs = [node.out for node in fg.toposort() if isinstance(node.op, ModelValuedVar)] + + dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs) + if not dependent_rvs: + continue + + # Issue warning for IntervalTransform on dependent RVs + for dependent_rv in dependent_rvs: + transform = dependent_rv.owner.op.transform + + if isinstance(transform, IntervalTransform) or ( + isinstance(transform, Chain) + and any(isinstance(tr, IntervalTransform) for tr in transform.transform_list) + ): + warnings.warn( + f"The transform {transform} for the variable {dependent_rv}, which depends on the " + f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.", + UserWarning, + ) + + # Check that no deterministics or potentials depend on the rv to marginalize + for det in model.deterministics: + if is_conditional_dependent(memo[det], rv_to_marginalize, all_rvs): + raise NotImplementedError( + f"Cannot marginalize {rv_to_marginalize} due to dependent Deterministic {det}" + ) + for pot in model.potentials: + if is_conditional_dependent(memo[pot], rv_to_marginalize, all_rvs): + raise NotImplementedError( + f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}" + ) + + marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs) + other_direct_rv_ancestors = [ + rv + for rv in find_conditional_input_rvs(dependent_rvs, all_rvs) + if rv is not rv_to_marginalize + ] + input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors)) + + if use_laplace: + # Q may reference variables of the original model; remap it to the fgraph clones + Q = marginalize_kwargs["Q"] + if not isinstance(Q, Variable): + Q = pt.as_tensor_variable(Q) + marginalize_kwargs["Q"] = memo.get(Q, Q).copy() + + replace_marginal_subgraph( + fg, rv_to_marginalize, dependent_rvs, input_rvs, use_laplace, **marginalize_kwargs + ) + + rewriter = marginal_rewrites_db.query(rewrite_query) + rewriter.rewrite(fg) + + remaining = [node for node in fg.toposort() if isinstance(node.op, MarginalSubgraphBase)] + for node in remaining: + marginalized_rv = node.inputs[0] + n_dep = node.op.n_dependent_rvs + dependent_rvs = node.inputs[1 : 1 + n_dep] + raise NotImplementedError( + f"Cannot marginalize {node.outputs[0]} with distribution " + f"{marginalized_rv.owner.op} and dependent variables " + f"{[rv.owner.op for rv in dependent_rvs]}. " + ) + + return model_from_fgraph(fg, mutate_fgraph=True) + + +def _validate_recover_var_names(var_names, marginalized_rv_names): + if var_names is None: + return list(marginalized_rv_names) + var_names = [var if isinstance(var, str) else var.name for var in var_names] + var_names_to_recover = [name for name in marginalized_rv_names if name in var_names] + missing_names = [name for name in var_names if name not in marginalized_rv_names] + if missing_names: + raise ValueError(f"Unrecognized var_names: {missing_names}") + return var_names_to_recover + + +def _find_laplace_marginalized_names(apply_nodes) -> list[str]: + """Names of Laplace-marginalized variables, including ones absorbed into other MarginalRVs.""" + names = [] + for node in apply_nodes: + if isinstance(node.op, MarginalLaplaceRV): + names.append(node.op.inner_outputs[0].name) + if isinstance(node.op, MarginalRV): + names.extend(_find_laplace_marginalized_names(node.op.fgraph.apply_nodes)) + return names + + +def unmarginalize(model: Model, rvs_to_unmarginalize: str | Sequence[str] | None = None) -> Model: + """Unmarginalize a subset of variables in a PyMC model. + + + Parameters + ---------- + model : Model + PyMC model to unmarginalize. Original variables well be cloned. + rvs_to_unmarginalize : str or sequence of str, optional + Variables to unmarginalize in the returned model. If None, all variables are + unmarginalized. + + Returns + ------- + unmarginal_model: Model + Model with the specified variables unmarginalized. + """ + + fg, _memo = fgraph_from_model(model) + + if rvs_to_unmarginalize is not None: + if not isinstance(rvs_to_unmarginalize, list | tuple): + rvs_to_unmarginalize = (rvs_to_unmarginalize,) + rvs_to_unmarginalize = set(rvs_to_unmarginalize) + + # Laplace-marginalized RVs that are kept marginalized would be re-marginalized + # without the Q / minimizer options they were created with, which cannot be + # recovered from the unmarginalized graph. + kept_laplace_rvs = [ + name + for name in _find_laplace_marginalized_names(fg.apply_nodes) + if name not in rvs_to_unmarginalize + ] + if kept_laplace_rvs: + raise NotImplementedError( + f"Laplace-marginalized variables {kept_laplace_rvs} cannot be kept marginalized " + "through a partial unmarginalize, because their precision matrix Q and minimizer " + "options are not currently preserved when re-marginalizing. Either include them " + "in rvs_to_unmarginalize, or rebuild the model with " + "marginalize(..., use_laplace=True) from scratch." + ) + + # Unmarginalize all the MarginalRVs + in2out(local_unmarginalize, ignore_newtrees=False).apply(fg) + unmarginalized_model = model_from_fgraph(fg, mutate_fgraph=True) + if rvs_to_unmarginalize is None: + return unmarginalized_model + + # Re-marginalize the variables we want to keep marginalized + old_free_rv_names = set(rv.name for rv in model.free_RVs) + new_free_rv_names = set( + rv.name for rv in unmarginalized_model.free_RVs if rv.name not in old_free_rv_names + ) + if rvs_to_unmarginalize - new_free_rv_names: + raise ValueError( + f"Unrecognized rvs_to_unmarginalize: {rvs_to_unmarginalize - new_free_rv_names}" + ) + rvs_to_keep_marginalized = tuple(new_free_rv_names - rvs_to_unmarginalize) + return marginalize(unmarginalized_model, rvs_to_keep_marginalized) + + +def conditional( + model: Model, + rvs_to_recover: ModelRVs | None = None, +) -> Model: + """Replace marginalized variables with their conditional distributions. + + Returns a new model where the specified marginalized variables become + free RVs whose distributions are their conditionals given the dependents. + Unspecified marginalized variables stay marginalized (integrated out). + + The returned model can be used with ``pm.sample_posterior_predictive`` + to draw conditional posterior samples, or with ``model.compile_logp`` + to evaluate conditional log-probabilities. + + The input is a marginalized model. Starting from an original model + factored as ``p(mu) * p(x|mu) * p(y|x)``, marginalizing ``x`` yields + ``p(mu) * p(y|mu)``. ``conditional`` adds ``x`` back as its conditional + distribution, giving ``p(mu) * p(y|mu) * p(x|y, mu)`` -- a re-factorization + of the same full joint ``p(mu, x, y)``: the recovered variable follows the + conditional ``p(x|y, mu)``, while each dependent stays marginalized over it. + + Selecting variables matters when evaluating logp: + ``model.compile_logp(vars=[model["x"]])`` gives the conditional + ``p(x|y, mu)``, while the unqualified ``model.compile_logp()`` is the full + joint ``p(mu, x, y)``. + + Parameters + ---------- + model : Model + PyMC model with marginalized variables. + rvs_to_recover : str, sequence of str, or None + Marginalized variables to recover. Defaults to all. + + Returns + ------- + Model + Model with the specified variables as free RVs with conditional + distributions. + + Examples + -------- + **Basic usage** — recover a marginalized variable: + + .. code-block:: python + + import pymc as pm + from pymc_extras.marginal import marginalize, conditional + + with pm.Model() as m: + p = pm.Beta("p", 1, 1) + idx = pm.Bernoulli("idx", p=p, shape=(3,)) + y = pm.Normal("y", pm.math.switch(idx, -10, 10), observed=[10, 10, -10]) + + marginal_m = marginalize(m, [idx]) + idata = pm.sample(model=marginal_m) + + # Get model with idx's conditional posterior as its distribution + cond_m = conditional(marginal_m) + logp_fn = cond_m.compile_logp(vars=[cond_m["idx"]]) + pm.sample_posterior_predictive(idata, model=cond_m, sample_vars=["idx"]) + + **Nested marginalization** — recover a subset (marginal posterior): + + When multiple variables are marginalized, specifying a subset recovers + those variables with the others integrated out (marginal posterior). + + .. code-block:: python + + with pm.Model() as m: + idx = pm.Bernoulli("idx", p=0.5) + sub_idx = pm.Bernoulli("sub_idx", p=f(idx)) + y = pm.Normal("y", mu=idx + sub_idx, sigma=1) + + marginal_m = marginalize(m, ["idx", "sub_idx"]) + + # Marginal posterior of idx (sub_idx integrated out): + # P(idx | y, σ) = Σ_sub_idx P(idx, sub_idx | y, σ) + cond_idx = conditional(marginal_m, "idx") + + # Marginal posterior of sub_idx (idx integrated out): + # P(sub_idx | y, σ) = Σ_idx P(idx, sub_idx | y, σ) + cond_sub = conditional(marginal_m, "sub_idx") + + **Recovering all nested variables** — joint posterior factorization: + + When recovering all marginalized variables at once, the joint + posterior is factored via the chain rule in recovery order. Each + variable integrates out the not-yet-recovered ones and conditions + on the already-recovered ones: + + .. code-block:: python + + # P(idx, sub_idx | y) = P(idx | y) · P(sub_idx | idx, y) + cond_all = conditional(marginal_m) + + # idx's logp does NOT depend on sub_idx (sub_idx is integrated out): + logp_idx = cond_all.compile_logp(vars=[cond_all["idx"]]) + + # sub_idx's logp depends on idx: + logp_sub = cond_all.compile_logp(vars=[cond_all["sub_idx"]]) + + The result is a valid generative DAG — draw exact joint posterior + samples by forward-sampling through it. + + **Full conditional via unmarginalize:** + + To get the full conditional ``P(idx | sub_idx, y)`` (conditioning + on ``sub_idx`` rather than integrating it out), first unmarginalize + ``sub_idx`` so it becomes a free RV with its original prior, then + conditionalize ``idx``: + + .. code-block:: python + + from pymc_extras.marginal import unmarginalize + + partial_m = unmarginalize(marginal_m, "sub_idx") + cond_full = conditional(partial_m, "idx") + # User must provide sub_idx values when evaluating + """ + unmarginal_model = unmarginalize(model) + marginalized_rv_names = _get_marginalized_rv_names(model, unmarginal_model) + + if rvs_to_recover is None: + var_names_to_recover = list(marginalized_rv_names) + else: + if isinstance(rvs_to_recover, str | Variable): + rvs_to_recover = (rvs_to_recover,) + var_names_to_recover = _validate_recover_var_names(rvs_to_recover, marginalized_rv_names) + + [n for n in marginalized_rv_names if n not in var_names_to_recover] + + if not var_names_to_recover: + return model + + # Chain-rule factorization of the joint posterior. The base is the + # marginal model — dependents keep their marginal distribution via the + # MarginalRV, so the conditional can reference them without cycles. + fg, _memo = fgraph_from_model(model) + + # Check if all requested vars can be found directly in fg. + # If any are nested, recover ALL vars via the chain-rule, then + # re-marginalize the unwanted ones (the chain-rule model IS the joint + # posterior, so Σ_unwanted p(all|y) = p(kept|y)). + all_direct = all(_find_marg_rv(fg, name)[0] is not None for name in var_names_to_recover) + vars_to_add = var_names_to_recover if all_direct else list(marginalized_rv_names) + + for var_name in vars_to_add: + marg_node, source_fg = _find_marg_rv(fg, var_name) + if marg_node is not None: + _add_conditional(fg, marg_node, source_fg, var_name) + else: + source_fg, _ = fgraph_from_model(marginalize(unmarginal_model, [var_name])) + marg_node, _ = _find_marg_rv(source_fg, var_name) + if marg_node is not None: + _add_conditional(fg, marg_node, source_fg, var_name) + else: + raise NotImplementedError( + f"Cannot build conditional for nested variable '{var_name}'. " + f"Use conditional(model) to recover all marginalized variables " + f"together, or unmarginalize the parent variables first." + ) + + result = model_from_fgraph(fg, mutate_fgraph=True) + + # Re-marginalize vars that were recovered only for the chain-rule + # but weren't requested by the user. + vars_to_remarginalize = [n for n in vars_to_add if n not in var_names_to_recover] + if vars_to_remarginalize: + result = marginalize(result, vars_to_remarginalize) + + return result + + +def _find_marg_rv(fg, var_name): + """Find the MarginalRV in ``fg`` whose marginalized variable is ``var_name``.""" + for node in fg.toposort(): + if isinstance(node.op, MarginalRV) and node.op.inner_outputs[0].name == var_name: + return node, fg + return None, fg + + +def _remap_to_fg(sample_graph, source_fg, fg): + """Remap ``sample_graph`` references from ``source_fg`` to ``fg``. + + When ``source_fg is fg`` this is a no-op. Otherwise, maps model RVs + and fgraph inputs by name so nothing from source_fg leaks into fg. + """ + if source_fg is fg: + return sample_graph + + remap = {} + + # Model RVs → fg model RVs (by name) + for src_node in source_fg.toposort(): + if not isinstance(src_node.op, ModelValuedVar): + continue + name = src_node.outputs[0].name + fg_node = next( + ( + n + for n in fg.toposort() + if isinstance(n.op, ModelValuedVar) and n.outputs[0].name == name + ), + None, + ) + if fg_node is not None: + remap[src_node.outputs[0]] = fg_node.outputs[0] + + # Fgraph inputs (value variables) → fg inputs (by name) + fg_inputs_by_name = { + getattr(inp, "name", None): inp for inp in fg.inputs if getattr(inp, "name", None) + } + for src_inp in source_fg.inputs: + src_name = getattr(src_inp, "name", None) + if src_name and src_name in fg_inputs_by_name: + remap[src_inp] = fg_inputs_by_name[src_name] + + if remap: + [sample_graph] = graph_replace([sample_graph], replace=remap, strict=False) + return sample_graph + + +def _add_conditional(fg, marg_node, source_fg, var_name): + """Dispatch on ``marg_node`` (from ``source_fg``), wire result into ``fg``, add as free RV.""" + from pymc.model.fgraph import ModelObservedRV, model_free_rv + + op = marg_node.op + n_dep = op.n_dependent_rvs + + # Dispatch → sample_graph with dep_dummies + sample_graph, dep_dummies = marginalized_conditional(op, marg_node) + replacements = dict(zip(op.inner_inputs, marg_node.inputs)) + [sample_graph] = graph_replace([sample_graph], replace=replacements, strict=False) + + # Map dep_dummies → fg's dependent model RVs (or observed data constants) + dep_remap = {} + for k, dep_output in enumerate(marg_node.outputs[1 : 1 + n_dep]): + clients = source_fg.clients.get(dep_output, []) + mv_client = next((c for c, _ in clients if isinstance(c.op, ModelValuedVar)), None) + dep_name = mv_client.outputs[0].name + is_observed = isinstance(mv_client.op, ModelObservedRV) + + fg_mv = next( + n + for n in fg.toposort() + if isinstance(n.op, ModelValuedVar) and n.outputs[0].name == dep_name + ) + dep_remap[dep_dummies[k]] = fg_mv.inputs[1] if is_observed else fg_mv.outputs[0] + + [sample_graph] = graph_replace([sample_graph], replace=dep_remap, strict=False) + + # Remap remaining source_fg references to fg + sample_graph = _remap_to_fg(sample_graph, source_fg, fg) + + # Import new shared variables (e.g. RNGs from Categorical.dist) + for inp in graph_inputs([sample_graph]): + if isinstance(inp, SharedVariable) and inp not in fg.inputs: + fg.add_input(inp) + + # Add the conditional as a new free RV + sample_graph.name = var_name + value = sample_graph.type() + value.name = var_name + fg.add_input(value) + conditional_free_rv = model_free_rv(sample_graph, value, None, *op.marginalized_dims) + fg.add_output(conditional_free_rv, reason="conditionalize") + + +def recover( + idata: DataTree, + *, + model: Model | None = None, + var_names: Sequence[str] | None = None, + extend_inferencedata: bool = True, + random_seed: RandomState = None, +): + """Sample marginalized variables from their conditional posterior. + + Builds the chain-rule factorization of the joint posterior via + :func:`conditional` and forward-samples all recovered variables + together. For more control, use :func:`conditional` directly. + + Parameters + ---------- + idata : DataTree + DataTree with posterior group. + model : Model, optional + PyMC model with marginalized variables. + var_names : sequence of str, optional + Variables to recover. Defaults to all marginalized variables. + extend_inferencedata : bool, default True + Whether to extend the original DataTree or return a new Dataset. + random_seed : int, array-like of int or SeedSequence, optional + Seed for generating samples. + + Returns + ------- + idata : DataTree or Dataset + DataTree with recovered samples added to posterior, or a new Dataset. + + Examples + -------- + .. code-block:: python + + import pymc as pm + from pymc_extras.marginal import marginalize, recover + + with pm.Model() as m: + p = pm.Beta("p", 1, 1) + x = pm.Bernoulli("x", p=p, shape=(3,)) + y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10]) + + marginal_m = marginalize(m, [x]) + idata = pm.sample(model=marginal_m) + recover(idata, model=marginal_m) + """ + import pymc as pm + + if isinstance(idata, Model): + raise TypeError( + "The order of arguments of `recover` changed. " "The first input must be an idata" + ) + + model = modelcontext(model) + unmarginal_model = unmarginalize(model) + marginalized_rv_names = _get_marginalized_rv_names(model, unmarginal_model) + var_names_to_recover = _validate_recover_var_names(var_names, marginalized_rv_names) + + if random_seed is not None: + _get_seeds_per_chain(random_seed, len(var_names_to_recover)) + else: + [None] * len(var_names_to_recover) + + # Build a single conditional model recovering all requested variables + # via the chain-rule factorization. This handles nested variables + # correctly (each conditions on the already-recovered ones and + # integrates out the not-yet-recovered ones). Sample all recovered + # variables together so the chain-rule dependencies are satisfied + # (e.g. sub_idx's conditional uses idx's sampled value). + cond_model = conditional(model, var_names_to_recover) + freeze = [rv.name for rv in cond_model.free_RVs if rv.name not in var_names_to_recover] + + sample_result = pm.sample_posterior_predictive( + idata, + model=cond_model, + sample_vars=var_names_to_recover, + freeze_vars=freeze, + random_seed=random_seed, + progressbar=False, + ) + pp = sample_result.posterior_predictive + pp_ds = pp.dataset if isinstance(pp, DataTree) else pp + + all_datasets = [pp_ds[var_names_to_recover]] + + if not all_datasets: + return idata + + rv_dataset = all_datasets[0] + for ds in all_datasets[1:]: + rv_dataset = merge([rv_dataset, ds], compat="override") + + if extend_inferencedata: + idata["posterior"] = idata["posterior"].assign(rv_dataset) + return idata + else: + return rv_dataset + + +def recover_marginals(*args, return_samples: bool = True, **kwargs): + """Deprecated alias for :func:`recover`. + + .. deprecated:: + ``recover_marginals`` has been renamed to :func:`recover` (available as + ``pymc_extras.marginal.recover``). Unlike the old implementation, it no + longer returns the posterior log-probabilities of the marginalized + variables (the ``lp_*`` arrays / ``return_samples=False`` mode); use + :func:`conditional` together with ``Model.compile_logp`` to evaluate + those instead. + """ + warnings.warn( + "`recover_marginals` has been renamed to `recover` and moved to the " + "`pymc_extras.marginal` namespace (`pymc_extras.marginal.recover`).", + FutureWarning, + stacklevel=2, + ) + if not return_samples: + raise NotImplementedError( + "`recover` no longer returns posterior log-probabilities of the " + "marginalized variables. Use `conditional(...)` with " + "`Model.compile_logp` to evaluate them instead." + ) + return recover(*args, **kwargs) + + +__all__ = ["conditional", "marginalize", "recover", "recover_marginals", "unmarginalize"] diff --git a/pymc_extras/model/marginal/rewrites.py b/pymc_extras/model/marginal/rewrites.py new file mode 100644 index 000000000..22088349d --- /dev/null +++ b/pymc_extras/model/marginal/rewrites.py @@ -0,0 +1,334 @@ +from pymc.distributions import Bernoulli, Categorical, DiscreteUniform +from pymc.model.fgraph import model_free_rv +from pymc.pytensorf import collect_default_updates +from pytensor.compile import SharedVariable +from pytensor.graph import Apply, Op, node_rewriter +from pytensor.graph.replace import graph_replace +from pytensor.graph.rewriting.db import EquilibriumDB +from pytensor.graph.traversal import graph_inputs + +from pymc_extras.distributions.timeseries import DiscreteMarkovChain +from pymc_extras.model.marginal.distributions.core import MarginalRV, inline_ofg_outputs +from pymc_extras.model.marginal.distributions.enumerable import ( + MarginalDiscreteMarkovChainRV, + MarginalFiniteDiscreteRV, +) +from pymc_extras.model.marginal.distributions.laplace import MarginalLaplaceRV +from pymc_extras.model.marginal.graph_analysis import subgraph_batch_dim_connection + + +class MarginalSubgraphBase(Op): + """Base for flat IR markers representing marginalized subgraphs. + + Inputs: [*subgraph_outputs, *boundary_vars] + Outputs: [marginalized_rv, *dependent_rvs] + + The marker delimits the Markov blanket of the marginalized RV: the + dependent RVs are its children, and the boundary contains its parents + and the children's other parents. Given the boundary, the marginalized + RV is conditionally independent of the rest of the model, so rewrites + can resolve the marker locally. + + The actual subgraph lives in the fgraph between the boundary vars + and the subgraph outputs. At rewrite time, the subgraph is cloned + out of the fgraph to build the OpFromGraph (MarginalRV subclass). + RNG updates are discovered at clone time, not stored on the marker. + """ + + def __init__(self, n_dependent_rvs, marginalized_dims, output_types): + self.n_dependent_rvs = n_dependent_rvs + self.marginalized_dims = marginalized_dims + self.output_types = output_types + super().__init__() + + def __eq__(self, other): + return self is other + + def __hash__(self): + return id(self) + + def make_node(self, *inputs): + outputs = [t() for t in self.output_types] + return Apply(self, list(inputs), outputs) + + @property + def n_subgraph_outputs(self): + return 1 + self.n_dependent_rvs + + def split_node_inputs(self, node): + """Split node.inputs into (subgraph_outputs, boundary).""" + n = self.n_subgraph_outputs + return list(node.inputs[:n]), list(node.inputs[n:]) + + def perform(self, node, inputs, outputs): + raise NotImplementedError("MarginalSubgraph should be resolved by rewrites") + + +class MarginalSubgraph(MarginalSubgraphBase): + """Ready-to-resolve marginalized subgraph marker.""" + + +class DeferredMarginalSubgraph(MarginalSubgraphBase): + """Marginalized subgraph whose inner deps are not yet resolved. + + Some dependent RVs come from unresolved MarginalSubgraph nodes. + Type-specific rewrites (finite_discrete_marginal, etc.) track + MarginalSubgraph and won't match this class. Once the inner + MarginalSubgraph nodes are resolved by the EquilibriumDB, + resolve_deferred_marginal_subgraph converts this to a plain + MarginalSubgraph so those rewrites can fire. + """ + + +class LaplaceMarginalSubgraph(MarginalSubgraphBase): + """Marginalized subgraph to be resolved via Laplace approximation. + + Created when the user calls ``marginalize(..., use_laplace=True)``. + The precision matrix Q of the marginalized variable is appended as the + last boundary input; the minimizer options are stored on the marker and + forwarded to the MarginalLaplaceRV. + """ + + def __init__( + self, + *args, + minimizer_seed: int, + minimizer_kwargs: dict = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}, + **kwargs, + ): + self.minimizer_seed = minimizer_seed + self.minimizer_kwargs = minimizer_kwargs + super().__init__(*args, **kwargs) + + +def extract_marginal_subgraph(node): + """Extract inputs/outputs from a MarginalSubgraph node for building an OpFromGraph. + + ModelValuedVar nodes inside the subgraph were already unwrapped by + _unwrap_subgraph_model_vars during replace_marginal_subgraph. RNG + updates are discovered here. The OpFromGraph constructor handles cloning. + + Returns (inputs, outputs) where outputs = [marginalized_rv, *deps, *rng_updates]. + """ + subgraph_outputs, boundary = node.op.split_node_inputs(node) + + n_rvs = 1 + node.op.n_dependent_rvs + rng_updates = collect_default_updates( + subgraph_outputs[:n_rvs], inputs=boundary, must_be_shared=False + ) + + outputs = subgraph_outputs + list(rng_updates.values()) + return boundary, outputs + + +@node_rewriter(tracks=[MarginalRV]) +def local_unmarginalize(fgraph, node): + all_outputs = inline_ofg_outputs(node.op, node.inputs) + n_dep = node.op.n_dependent_rvs + unmarginalized_rv = all_outputs[0] + dependent_rvs = list(all_outputs[1 : 1 + n_dep]) + rngs = list(all_outputs[1 + n_dep :]) + + value = unmarginalized_rv.clone() + fgraph.add_input(value) + transform = None + unmarginalized_free_rv = model_free_rv( + unmarginalized_rv, value, transform, *node.op.marginalized_dims + ) + + dependent_rvs = graph_replace(dependent_rvs, {unmarginalized_rv: unmarginalized_free_rv}) + + return [unmarginalized_free_rv, *dependent_rvs, *rngs] + + +marginal_rewrites_db = EquilibriumDB() +marginal_rewrites_db.name = "marginal_rewrites_db" + + +@node_rewriter(tracks=[MarginalSubgraph]) +def finite_discrete_marginal(fgraph, node): + op = node.op + n_dep = op.n_dependent_rvs + + inputs, outputs = extract_marginal_subgraph(node) + marginalized_rv = outputs[0] + + marginalized_rv_op = marginalized_rv.owner.op + if not isinstance( + marginalized_rv_op, Bernoulli | Categorical | DiscreteUniform | DiscreteMarkovChain + ): + return None + + if isinstance(marginalized_rv_op, DiscreteMarkovChain): + if marginalized_rv_op.n_lags > 1: + raise NotImplementedError( + "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported" + ) + if marginalized_rv.owner.inputs[0].type.ndim > 2: + raise NotImplementedError( + "Marginalization for DiscreteMarkovChain with non-matrix transition probability " + "is not supported" + ) + + try: + dependent_rvs_dim_connections = subgraph_batch_dim_connection( + marginalized_rv, outputs[1 : 1 + n_dep] + ) + except (ValueError, NotImplementedError) as e: + raise type(e)( + "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. " + "You can try splitting the marginalized RV into separate components and marginalizing " + f"them separately. {e}" + ) from e + + if isinstance(marginalized_rv_op, DiscreteMarkovChain): + constructor = MarginalDiscreteMarkovChainRV + else: + constructor = MarginalFiniteDiscreteRV + + typed_op = constructor( + inputs=inputs, + outputs=outputs, + dims_connections=dependent_rvs_dim_connections, + marginalized_dims=op.marginalized_dims, + n_dependent_rvs=n_dep, + ) + + new_outputs = typed_op(*inputs) + if not isinstance(new_outputs, list): + new_outputs = list(new_outputs) + return new_outputs[: len(node.outputs)] + + +marginal_rewrites_db.register("finite_discrete_marginal", finite_discrete_marginal, "basic") + + +@node_rewriter(tracks=[LaplaceMarginalSubgraph]) +def laplace_marginal(fgraph, node): + op = node.op + + # Q was appended as the last boundary input and is kept as a dummy input + # of the OpFromGraph (popped again by the logp implementation) + inputs, outputs = extract_marginal_subgraph(node) + + typed_op = MarginalLaplaceRV( + inputs=inputs, + outputs=outputs, + marginalized_dims=op.marginalized_dims, + n_dependent_rvs=op.n_dependent_rvs, + minimizer_seed=op.minimizer_seed, + minimizer_kwargs=op.minimizer_kwargs, + ) + + new_outputs = typed_op(*inputs) + if not isinstance(new_outputs, list): + new_outputs = list(new_outputs) + return new_outputs[: len(node.outputs)] + + +marginal_rewrites_db.register("laplace_marginal", laplace_marginal, "basic") + + +@node_rewriter(tracks=[MarginalSubgraph]) +def unwrap_inner_marginal_rv(fgraph, node): + """Unwrap a MarginalRV inside a MarginalSubgraph's subgraph. + + When a variable absorbed by a prior marginalize() call is re-marginalized, + its raw RV comes from a MarginalRV (OpFromGraph). This rewrite inlines that + MarginalRV and rebuilds as nested MarginalSubgraph markers that the + type-specific rewrites can handle. + """ + subgraph_outputs, boundary = node.op.split_node_inputs(node) + marginalized_rv = subgraph_outputs[0] + outer_dep_outputs = subgraph_outputs[1:] + + if not (marginalized_rv.owner and isinstance(marginalized_rv.owner.op, MarginalRV)): + return None + + marg_rv_node = marginalized_rv.owner + marg_rv_op = marg_rv_node.op + + # Inline the MarginalRV to get raw variables + inlined = inline_ofg_outputs(marg_rv_op, marg_rv_node.inputs) + inlined_marginalized = inlined[0] + inlined_deps = list(inlined[1 : 1 + marg_rv_op.n_dependent_rvs]) + + # Map MFD dep outputs → inlined raw variables + target_idx = list(marg_rv_node.outputs).index(marginalized_rv) - 1 + target_inlined = inlined_deps[target_idx] + deps_inlined = [ + inlined_deps[list(marg_rv_node.outputs).index(d) - 1] for d in outer_dep_outputs + ] + + def _shared_boundary(outputs, base_boundary): + return base_boundary + [ + inp + for inp in graph_inputs(outputs, blockers=base_boundary) + if isinstance(inp, SharedVariable) and inp not in base_boundary + ] + + # Inner MS: marginalize the target variable (e.g. sub_idx), deps are outer deps + # Compute boundary from scratch — only shared vars actually used by this subgraph. + # Block inlined_marginalized so idx's RNG doesn't leak into the inner boundary. + inner_subgraph = [target_inlined, *deps_inlined] + inner_boundary = _shared_boundary(inner_subgraph, [inlined_marginalized]) + inner_ms = MarginalSubgraph( + n_dependent_rvs=len(deps_inlined), + marginalized_dims=node.op.marginalized_dims, + output_types=[o.type for o in inner_subgraph], + ) + inner_outs = inner_ms(*(inner_subgraph + inner_boundary)) + if not isinstance(inner_outs, list): + inner_outs = list(inner_outs) + + # Outer DeferredMS: marginalize the previously-marginalized variable (e.g. idx) + # Use original boundary (not inner_boundary) so inlined_marginalized stays internal + outer_subgraph = [inlined_marginalized, *inner_outs[1:]] + outer_boundary = _shared_boundary(outer_subgraph, list(boundary)) + outer_ms = DeferredMarginalSubgraph( + n_dependent_rvs=len(deps_inlined), + marginalized_dims=marg_rv_op.marginalized_dims, + output_types=[o.type for o in outer_subgraph], + ) + outer_outs = outer_ms(*(outer_subgraph + outer_boundary)) + if not isinstance(outer_outs, list): + outer_outs = list(outer_outs) + + return outer_outs[: len(node.outputs)] + + +marginal_rewrites_db.register( + "unwrap_inner_marginal_rv", unwrap_inner_marginal_rv, "basic", "unwrap" +) + + +@node_rewriter(tracks=[DeferredMarginalSubgraph]) +def resolve_deferred_marginal_subgraph(fgraph, node): + """Convert DeferredMarginalSubgraph to MarginalSubgraph once inner deps are resolved. + + The EquilibriumDB resolves inner MarginalSubgraph nodes first (they live in + the same fgraph). Once none of this node's inputs come from a + MarginalSubgraph, this rewrite promotes it to a plain MarginalSubgraph + so the type-specific rewrites can fire. + """ + for inp in node.inputs: + if inp.owner is not None and isinstance(inp.owner.op, MarginalSubgraphBase): + return None + + op = node.op + resolved_op = MarginalSubgraph( + n_dependent_rvs=op.n_dependent_rvs, + marginalized_dims=op.marginalized_dims, + output_types=op.output_types, + ) + new_outputs = resolved_op(*node.inputs) + if not isinstance(new_outputs, list): + new_outputs = list(new_outputs) + return new_outputs + + +marginal_rewrites_db.register( + "resolve_deferred_marginal_subgraph", + resolve_deferred_marginal_subgraph, + "basic", +) diff --git a/pyproject.toml b/pyproject.toml index 544d754fe..5cae036ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ keywords = [ license = {file = "LICENSE"} dynamic = ["version"] # specify the version in the __init__.py file dependencies = [ - "pymc>=6.0,<7.0", + "pymc>=6.0.1,<7.0", "pytensor>=3.0.4", "arviz>=1.1", "better-optimize>=0.4.2,<1.0", diff --git a/tests/model/marginal/test_distributions.py b/tests/model/marginal/test_distributions.py index 30806e67e..85882e9bc 100644 --- a/tests/model/marginal/test_distributions.py +++ b/tests/model/marginal/test_distributions.py @@ -20,8 +20,9 @@ def test_marginalized_bernoulli_logp(): marginal_rv_node = MarginalFiniteDiscreteRV( [mu], [idx, y], + n_dependent_rvs=1, dims_connections=(((),),), - dims=(), + marginalized_dims=(), )(mu)[0].owner y_vv = y.clone() diff --git a/tests/model/marginal/test_marginal_model.py b/tests/model/marginal/test_model.py similarity index 91% rename from tests/model/marginal/test_marginal_model.py rename to tests/model/marginal/test_model.py index e677f3aa5..bd0d96f73 100644 --- a/tests/model/marginal/test_marginal_model.py +++ b/tests/model/marginal/test_model.py @@ -15,13 +15,12 @@ from pymc.initial_point import make_initial_point_expression from pymc.pytensorf import constant_fold, inputvars from pymc.util import UNSET -from scipy.special import log_softmax, logsumexp -from scipy.stats import halfnorm, norm +from scipy.special import logsumexp -from pymc_extras.model.marginal.distributions import MarginalRV -from pymc_extras.model.marginal.marginal_model import ( +from pymc_extras.model.marginal.distributions.core import MarginalRV +from pymc_extras.model.marginal.model import ( marginalize, - recover_marginals, + recover, unmarginalize, ) from pymc_extras.utils.model_equivalence import equal_computations_up_to_root, equivalent_models @@ -275,6 +274,27 @@ def build_model(build_batched: bool) -> Model: np.testing.assert_almost_equal(logp, ref_logp) +def test_sequential_marginalization(): + """Test that sequential marginalization is equivalent to joint marginalization.""" + + def build_model(): + with Model() as m: + idx = pm.Bernoulli("idx", p=0.5) + sub_idx = pm.Bernoulli("sub_idx", p=pt.as_tensor([0.3, 0.7])[idx]) + x = pm.Normal("x", mu=(idx + sub_idx) - 1) + return m + + joint_m = marginalize(build_model(), ["idx", "sub_idx"]) + + # idx first: sub_idx becomes a dependent of idx's marginalization + seq_idx_first = marginalize(marginalize(build_model(), "idx"), "sub_idx") + assert equivalent_models(seq_idx_first, joint_m) + + # sub_idx first: idx remains a plain free RV (sub_idx depends on idx, not vice versa) + seq_sub_first = marginalize(marginalize(build_model(), "sub_idx"), "idx") + assert equivalent_models(seq_sub_first, joint_m) + + def test_interdependent_rvs(): """Test Marginalization when dependent RVs are interdependent.""" with Model() as m: @@ -406,7 +426,7 @@ def test_mixed_dims_via_transposed_dot(self): idx = pm.Bernoulli("idx", p=0.7, shape=2) y = pm.Normal("y", mu=idx @ idx.T) - with pytest.raises(NotImplementedError): + with pytest.raises((ValueError, NotImplementedError)): marginalize(m, idx) def test_mixed_dims_via_indexing(self): @@ -415,13 +435,13 @@ def test_mixed_dims_via_indexing(self): with Model() as m: idx = pm.Bernoulli("idx", p=0.7, shape=2) y = pm.Normal("y", mu=mean[idx, :] + mean[:, idx]) - with pytest.raises(NotImplementedError): + with pytest.raises((ValueError, NotImplementedError)): marginalize(m, idx) with Model() as m: idx = pm.Bernoulli("idx", p=0.7, shape=2) y = pm.Normal("y", mu=mean[idx, None] + mean[None, idx]) - with pytest.raises(NotImplementedError): + with pytest.raises((ValueError, NotImplementedError)): marginalize(m, idx) with Model() as m: @@ -430,33 +450,33 @@ def test_mixed_dims_via_indexing(self): mean[None, :][:, idx], 0 ) y = pm.Normal("y", mu=mu) - with pytest.raises(NotImplementedError): + with pytest.raises((ValueError, NotImplementedError)): marginalize(m, idx) with Model() as m: idx = pm.Bernoulli("idx", p=0.7, shape=2) y = pm.Normal("y", mu=idx[0] + idx[1]) - with pytest.raises(NotImplementedError): + with pytest.raises((ValueError, NotImplementedError)): marginalize(m, idx) def test_mixed_dims_via_vector_indexing(self): with Model() as m: idx = pm.Bernoulli("idx", p=0.7, shape=2) y = pm.Normal("y", mu=idx[[0, 1, 0, 0]]) - with pytest.raises(NotImplementedError): + with pytest.raises((ValueError, NotImplementedError)): marginalize(m, idx) with Model() as m: idx = pm.Categorical("key", p=[0.1, 0.3, 0.6], shape=(2, 2)) y = pm.Normal("y", pt.as_tensor([[0, 1], [2, 3]])[idx.astype(bool)]) - with pytest.raises(NotImplementedError): + with pytest.raises((ValueError, NotImplementedError)): marginalize(m, idx) def test_mixed_dims_via_support_dimension(self): with Model() as m: x = pm.Bernoulli("x", p=0.7, shape=3) y = pm.Dirichlet("y", a=x * 10 + 1) - with pytest.raises(NotImplementedError): + with pytest.raises((ValueError, NotImplementedError)): marginalize(m, x) def test_mixed_dims_via_nested_marginalization(self): @@ -465,7 +485,7 @@ def test_mixed_dims_via_nested_marginalization(self): y = pm.Bernoulli("y", p=0.7, shape=(2,)) z = pm.Normal("z", mu=pt.add.outer(x, y), shape=(3, 2)) - with pytest.raises(NotImplementedError): + with pytest.raises((ValueError, NotImplementedError)): marginalize(m, [x, y]) @@ -847,33 +867,14 @@ def test_basic(self, explicit_model): ) if explicit_model: - idata = recover_marginals(idata, model=marginal_m, return_samples=True) + idata = recover(idata, model=marginal_m) else: with marginal_m: - idata = recover_marginals(idata, return_samples=True) + idata = recover(idata) post = idata.posterior assert "k" in post - assert "lp_k" in post assert post.k.shape == post.y.shape - assert post.lp_k.shape == (*post.k.shape, len(p)) - - def true_logp(y, sigma): - y = y.repeat(len(p)).reshape(len(y), -1) - sigma = sigma.repeat(len(p)).reshape(len(sigma), -1) - return log_softmax( - np.log(p) - + norm.logpdf(y, loc=mu, scale=sigma) - + halfnorm.logpdf(sigma) - + np.log(sigma), - axis=1, - ) - - np.testing.assert_almost_equal( - true_logp(post.y.values.flatten(), post.sigma.values.flatten()), - post.lp_k[0].values, - ) - np.testing.assert_almost_equal(logsumexp(post.lp_k, axis=-1), 0) def test_coords(self): """Test if coords can be recovered with marginalized value had it originally""" @@ -894,10 +895,10 @@ def test_coords(self): idata = from_dict({"posterior": {k: np.expand_dims(prior[k], axis=0) for k in prior}}) with marginal_m: - idata = recover_marginals(idata, return_samples=True) + idata = recover(idata) post = idata.posterior + assert "idx" in post assert post.idx.dims == ("chain", "draw", "year") - assert post.lp_idx.dims == ("chain", "draw", "year", "lp_idx_dim") def test_batched(self): """Test that marginalization works for batched random variables""" @@ -918,11 +919,10 @@ def test_batched(self): ) idata = from_dict({"posterior": {k: np.expand_dims(prior[k], axis=0) for k in prior}}) - idata = recover_marginals(idata, return_samples=True) + idata = recover(idata) post = idata.posterior assert post["y"].shape == (1, 20, 2, 3) assert post["idx"].shape == (1, 20, 3, 2) - assert post["lp_idx"].shape == (1, 20, 3, 2, 2) def test_nested(self): """Test that marginalization works when there are nested marginalized RVs""" @@ -946,38 +946,12 @@ def test_nested(self): {"posterior": {k: np.expand_dims(v, axis=0) for k, v in prior.items()}} ) - idata = recover_marginals(idata, return_samples=True) + idata = recover(idata) post = idata.posterior assert "idx" in post - assert "lp_idx" in post assert post.idx.shape == post.y.shape - assert post.lp_idx.shape == (*post.idx.shape, 2) assert "sub_idx" in post - assert "lp_sub_idx" in post assert post.sub_idx.shape == post.y.shape - assert post.lp_sub_idx.shape == (*post.sub_idx.shape, 2) - - def true_idx_logp(y): - idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.15 * 0.25 * norm.pdf(y, loc=1)) - idx_1 = np.log(0.05 * 0.75 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2)) - return log_softmax(np.stack([idx_0, idx_1]).T, axis=1) - - np.testing.assert_almost_equal( - true_idx_logp(post.y.values.flatten()), - post.lp_idx[0].values, - ) - - def true_sub_idx_logp(y): - sub_idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.05 * 0.75 * norm.pdf(y, loc=1)) - sub_idx_1 = np.log(0.15 * 0.25 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2)) - return log_softmax(np.stack([sub_idx_0, sub_idx_1]).T, axis=1) - - np.testing.assert_almost_equal( - true_sub_idx_logp(post.y.values.flatten()), - post.lp_sub_idx[0].values, - ) - np.testing.assert_almost_equal(logsumexp(post.lp_idx, axis=-1), 0) - np.testing.assert_almost_equal(logsumexp(post.lp_sub_idx, axis=-1), 0) def test_forward_after_sampling(): @@ -992,7 +966,7 @@ def test_forward_after_sampling(): marginalized_mod = marginalize(m, [is_outlier]) - # Check that model.initial_point() does not modify the inner graph of the MarginalRV + # Check that model.initial_point() does not modify the inner graph of the marginalization Op marginal_rv = marginalized_mod["y_hat"] inner_outputs_before = marginal_rv.owner.op.fgraph.clone().outputs marginalized_mod.initial_point() diff --git a/tests/model/marginal/test_rewrites.py b/tests/model/marginal/test_rewrites.py new file mode 100644 index 000000000..c36ae313c --- /dev/null +++ b/tests/model/marginal/test_rewrites.py @@ -0,0 +1,260 @@ +import numpy as np +import pymc as pm +import pytensor.tensor as pt +import pytest +import scipy + +from arviz import from_dict +from pymc.model.transform.conditioning import remove_value_transforms + +from pymc_extras import marginalize +from pymc_extras.model.marginal.model import conditional, unmarginalize + + +def compute_conditional_logprob(cond_model, var_name, domain, point): + """Compute log P(var=k | data, params) for each k in domain. + + This is the pattern users would follow to evaluate conditional + log-probabilities from a model returned by ``conditional()``. + + Parameters + ---------- + cond_model : pm.Model + Model returned by ``conditional()``. + var_name : str + Name of the recovered variable. + domain : array-like + Domain values to evaluate. + point : dict + Values for all other variables in the model. + + Returns + ------- + logps : array + Log-probabilities for each domain value. + """ + logp_fn = cond_model.compile_logp(vars=[cond_model[var_name]]) + return np.array([logp_fn({**point, var_name: k}) for k in domain]) + + +class TestConditional: + def test_finite_discrete_logp(self): + """Test that conditional gives correct conditional logp for Bernoulli.""" + with pm.Model() as m: + sigma = pm.HalfNormal("sigma") + idx = pm.Bernoulli("idx", p=0.5) + y = pm.Normal("y", mu=idx, sigma=sigma) + + marginal_m = marginalize(m, "idx") + cond_m = conditional(marginal_m) + + assert "idx" in [rv.name for rv in cond_m.free_RVs] + + logps = compute_conditional_logprob( + cond_m, "idx", domain=[0, 1], point={"sigma_log__": 0.0, "y": 2.0} + ) + + # Manual: log P(idx=k | y=2, sigma=1) ∝ log P(y=2|idx=k, sigma=1) + log P(idx=k) + expected = scipy.special.log_softmax( + [scipy.stats.norm.logpdf(2.0, k, 1) + np.log(0.5) for k in (0, 1)] + ) + np.testing.assert_allclose(logps, expected) + + def test_with_remove_value_transforms(self): + """Test that remove_value_transforms + conditional gives natural-scale inputs.""" + with pm.Model() as m: + sigma = pm.HalfNormal("sigma") + idx = pm.Bernoulli("idx", p=0.5) + y = pm.Normal("y", mu=idx, sigma=sigma) + + marginal_m = marginalize(m, "idx") + + # Default: should have transformed value variable + cond_m = conditional(marginal_m) + assert any( + rv.name == "sigma" and cond_m.rvs_to_transforms[rv] is not None + for rv in cond_m.free_RVs + ) + + # With remove_value_transforms: natural scale + natural_m = remove_value_transforms(marginal_m) + cond_nat = conditional(natural_m) + assert all(cond_nat.rvs_to_transforms[rv] is None for rv in cond_nat.free_RVs) + + # Both should give the same logp result + logp_fn = cond_m.compile_logp(vars=[cond_m["idx"]]) + logp_fn_nat = cond_nat.compile_logp(vars=[cond_nat["idx"]]) + + lp = logp_fn({"sigma_log__": 0.0, "y": 2.0, "idx": 1}) + lp_nat = logp_fn_nat({"sigma": 1.0, "y": 2.0, "idx": 1}) + np.testing.assert_allclose(lp, lp_nat) + + def test_categorical_conditional(self): + """Test conditional with Categorical marginalized variable.""" + p = np.array([0.1, 0.3, 0.6]) + mu = np.array([-3.0, 0.0, 3.0]) + + with pm.Model() as m: + k = pm.Categorical("k", p=p) + y = pm.Normal("y", mu=pt.as_tensor(mu)[k], sigma=1.0) + + marginal_m = marginalize(m, "k") + cond_m = conditional(marginal_m) + + y_val = 2.5 + logps = compute_conditional_logprob(cond_m, "k", domain=range(3), point={"y": y_val}) + expected = scipy.special.log_softmax(np.log(p) + scipy.stats.norm.logpdf(y_val, mu, 1.0)) + np.testing.assert_allclose(logps, expected) + + def test_marginal_vs_full_conditional(self): + """Test marginal posterior vs full conditional via unmarginalize.""" + with pm.Model() as m: + idx = pm.Bernoulli("idx", p=0.5) + sub_idx = pm.Bernoulli("sub_idx", p=pt.as_tensor([0.3, 0.7])[idx]) + y = pm.Normal("y", mu=(idx + sub_idx) - 1, sigma=0.5) + + marginal_m = marginalize(m, ["idx", "sub_idx"]) + + # Marginal posterior: P(idx | y) with sub_idx integrated out + cond_marginal = conditional(marginal_m, "idx") + assert "sub_idx" not in [rv.name for rv in cond_marginal.free_RVs] + logp_marginal = cond_marginal.compile_logp(vars=[cond_marginal["idx"]]) + lp_marginal = [logp_marginal({"y": 0.5, "idx": k}) for k in (0, 1)] + np.testing.assert_allclose(scipy.special.logsumexp(lp_marginal), 0.0, atol=1e-14) + + # Full conditional: P(idx | sub_idx, y) via unmarginalize + partial_m = unmarginalize(marginal_m, "sub_idx") + cond_full = conditional(partial_m, "idx") + assert "sub_idx" in [rv.name for rv in cond_full.free_RVs] + with pytest.warns(match="multiple dependent variables"): + logp_full = cond_full.compile_logp(vars=[cond_full["idx"]]) + + # Full conditional depends on sub_idx — different answers for different sub_idx values + lp_given_sub0 = [logp_full({"y": 0.5, "sub_idx": 0, "idx": k}) for k in (0, 1)] + lp_given_sub1 = [logp_full({"y": 0.5, "sub_idx": 1, "idx": k}) for k in (0, 1)] + np.testing.assert_allclose(scipy.special.logsumexp(lp_given_sub0), 0.0, atol=1e-14) + np.testing.assert_allclose(scipy.special.logsumexp(lp_given_sub1), 0.0, atol=1e-14) + + # Full conditionals should differ from each other and from the marginal + assert not np.allclose(lp_given_sub0, lp_given_sub1) + assert not np.allclose(lp_given_sub0, lp_marginal) + + def test_recover_all_nested(self): + """Test recovering all nested variables gives chain-rule factorization.""" + with pm.Model() as m: + idx = pm.Bernoulli("idx", p=0.5) + sub_idx = pm.Bernoulli("sub_idx", p=pt.as_tensor([0.3, 0.7])[idx]) + y = pm.Normal("y", mu=(idx + sub_idx) - 1, sigma=0.5) + + marginal_m = marginalize(m, ["idx", "sub_idx"]) + cond_all = conditional(marginal_m) + + assert set(rv.name for rv in cond_all.free_RVs) == {"idx", "sub_idx", "y"} + + # Each logp is a valid conditional (sums to 1 over domain) + logp_idx = cond_all.compile_logp(vars=[cond_all["idx"]]) + logp_sub = cond_all.compile_logp(vars=[cond_all["sub_idx"]]) + + tp = {"y": 0.5, "idx": 1, "sub_idx": 0} + lp_idx = [logp_idx({**tp, "idx": k}) for k in (0, 1)] + lp_sub = [logp_sub({**tp, "sub_idx": k}) for k in (0, 1)] + np.testing.assert_allclose(scipy.special.logsumexp(lp_idx), 0.0, atol=1e-14) + np.testing.assert_allclose(scipy.special.logsumexp(lp_sub), 0.0, atol=1e-14) + + # Chain-rule factorization: idx has P(idx | y) (sub_idx integrated out), + # so idx's logp does NOT depend on sub_idx + lp_a = logp_idx({"y": 0.5, "idx": 0, "sub_idx": 0}) + lp_b = logp_idx({"y": 0.5, "idx": 0, "sub_idx": 1}) + assert np.isclose(lp_a, lp_b) + + # sub_idx has P(sub_idx | idx, y), so it DOES depend on idx + lp_c = logp_sub({"y": 0.5, "idx": 0, "sub_idx": 0}) + lp_d = logp_sub({"y": 0.5, "idx": 1, "sub_idx": 0}) + assert not np.isclose(lp_c, lp_d) + + def test_recover_nested_subset(self): + """Test recovering a nested variable with its parent integrated out. + + Uses a 3-category idx so marginal posteriors of idx and sub_idx + have different shapes and are numerically distinguishable. + """ + p_idx = np.array([0.1, 0.3, 0.6]) + p_sub_given_idx = np.array([0.2, 0.8, 0.5]) + mu = np.array([[-1, 1], [0, 3], [2, 5]], dtype="float64") # [idx, sub_idx] + + with pm.Model() as m: + idx = pm.Categorical("idx", p=p_idx) + sub_idx = pm.Bernoulli("sub_idx", p=pt.as_tensor(p_sub_given_idx)[idx]) + y = pm.Normal("y", mu=pt.as_tensor(mu)[idx, sub_idx], sigma=1.0, observed=2.5) + + marginal_m = marginalize(m, ["idx", "sub_idx"]) + + # Manual joint log-probabilities: log p(idx=k, sub_idx=j, y) + y_val = 2.5 + log_joints = np.zeros((3, 2)) + for k in range(3): + for j in range(2): + p_s = p_sub_given_idx[k] if j == 1 else 1 - p_sub_given_idx[k] + log_joints[k, j] = ( + np.log(p_idx[k]) + np.log(p_s) + scipy.stats.norm.logpdf(y_val, mu[k, j], 1.0) + ) + + # Recover sub_idx only — idx integrated out + cond_sub = conditional(marginal_m, "sub_idx") + assert "idx" not in [rv.name for rv in cond_sub.free_RVs] + assert "sub_idx" in [rv.name for rv in cond_sub.free_RVs] + + logp_sub_fn = cond_sub.compile_logp(vars=[cond_sub["sub_idx"]]) + actual_sub = np.array([logp_sub_fn({"sub_idx": j}) for j in range(2)]) + expected_sub = scipy.special.log_softmax(scipy.special.logsumexp(log_joints, axis=0)) + np.testing.assert_allclose(actual_sub, expected_sub) + np.testing.assert_allclose(scipy.special.logsumexp(actual_sub), 0.0, atol=1e-14) + + # Recover idx only — sub_idx integrated out + cond_idx = conditional(marginal_m, "idx") + assert "sub_idx" not in [rv.name for rv in cond_idx.free_RVs] + + logp_idx_fn = cond_idx.compile_logp(vars=[cond_idx["idx"]]) + actual_idx = np.array([logp_idx_fn({"idx": k}) for k in range(3)]) + expected_idx = scipy.special.log_softmax(scipy.special.logsumexp(log_joints, axis=1)) + np.testing.assert_allclose(actual_idx, expected_idx) + np.testing.assert_allclose(scipy.special.logsumexp(actual_idx), 0.0, atol=1e-14) + + def test_recover_independent_variables(self): + """Test recovering multiple independent marginalized variables.""" + with pm.Model() as m: + idx1 = pm.Bernoulli("idx1", p=0.75) + x = pm.Normal("x", mu=idx1) + idx2 = pm.Bernoulli("idx2", p=0.75, shape=(5,)) + y = pm.Normal("y", mu=(idx2 * 2 - 1), shape=(5,)) + + marginal_m = marginalize(m, [idx1, idx2]) + cond_m = conditional(marginal_m) + + assert set(rv.name for rv in cond_m.free_RVs) == {"idx1", "idx2", "x", "y"} + + logp_idx1 = cond_m.compile_logp(vars=[cond_m["idx1"]]) + + tp = {"x": 0.5, "y": np.zeros(5), "idx1": 0, "idx2": np.zeros(5, dtype=int)} + lp1 = [logp_idx1({**tp, "idx1": k}) for k in (0, 1)] + np.testing.assert_allclose(scipy.special.logsumexp(lp1), 0.0, atol=1e-14) + + def test_sample_posterior_predictive_single(self): + """Test sample_posterior_predictive with a single recovered variable.""" + with pm.Model() as m: + sigma = pm.HalfNormal("sigma") + idx = pm.Bernoulli("idx", p=0.5) + y = pm.Normal("y", mu=idx, sigma=sigma, observed=1.0) + + marginal_m = marginalize(m, "idx") + cond_m = conditional(marginal_m) + + # sigma=0.01 → y=1 strongly favors idx=1 (mu=1) + idata = from_dict({"posterior": {"sigma": np.full((1, 50), 0.01)}}) + result = pm.sample_posterior_predictive( + idata, + model=cond_m, + sample_vars=["idx"], + random_seed=42, + ) + assert result.posterior_predictive.idx.values.mean() > 0.99 From 28d51c424c223eacbf09450ea3a64e38291c5dac Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 29 May 2026 18:27:30 +0200 Subject: [PATCH 2/3] Add normal-normal closed-form marginalization --- .../model/marginal/distributions/__init__.py | 3 +- .../model/marginal/distributions/normal.py | 64 ++++++++++ pymc_extras/model/marginal/rewrites.py | 55 ++++++++- tests/model/marginal/test_rewrites.py | 109 +++++++++++++++++- 4 files changed, 227 insertions(+), 4 deletions(-) create mode 100644 pymc_extras/model/marginal/distributions/normal.py diff --git a/pymc_extras/model/marginal/distributions/__init__.py b/pymc_extras/model/marginal/distributions/__init__.py index 7e52e2ee9..8b3b158d4 100644 --- a/pymc_extras/model/marginal/distributions/__init__.py +++ b/pymc_extras/model/marginal/distributions/__init__.py @@ -1,5 +1,6 @@ import pymc_extras.model.marginal.distributions.enumerable -import pymc_extras.model.marginal.distributions.laplace # noqa: F401 +import pymc_extras.model.marginal.distributions.laplace +import pymc_extras.model.marginal.distributions.normal # noqa: F401 from pymc_extras.model.marginal.distributions.enumerable import ( MarginalFiniteDiscreteRV, # noqa: F401 diff --git a/pymc_extras/model/marginal/distributions/normal.py b/pymc_extras/model/marginal/distributions/normal.py new file mode 100644 index 000000000..db719ef66 --- /dev/null +++ b/pymc_extras/model/marginal/distributions/normal.py @@ -0,0 +1,64 @@ +from pymc import Normal +from pymc.logprob.abstract import _logprob +from pymc.logprob.basic import logp +from pymc.pytensorf import get_symbolic_rv_shapes +from pytensor.graph.replace import graph_replace +from pytensor.tensor import broadcast_to, constant, sqrt + +from pymc_extras.model.marginal.distributions.core import ( + MarginalRV, + inline_ofg_outputs, + marginalized_conditional, +) + + +class NormalNormalMarginalRV(MarginalRV): + """Marginalized Normal-Normal conjugate pair. + + Inner graph: [marginalized_normal, dependent_normal, *rng_updates] + """ + + def __init__(self, *args, marginalized_dims, **kwargs): + self.marginalized_dims = marginalized_dims + self.n_dependent_rvs = 1 + super().__init__(*args, **kwargs) + + +@_logprob.register(NormalNormalMarginalRV) +def normal_normal_marginal_rv_logp(op: NormalNormalMarginalRV, values, *inputs, **kwargs): + [value] = values + + all_outputs = inline_ofg_outputs(op, inputs) + marginalized_rv = all_outputs[0] + dependent_rv = all_outputs[1] + + mu_m, sigma_m = marginalized_rv.owner.op.dist_params(marginalized_rv.owner) + mu_d, sigma_d = dependent_rv.owner.op.dist_params(dependent_rv.owner) + + if marginalized_rv.type.broadcastable != mu_m.type.broadcastable: + mu_m = broadcast_to(mu_m, get_symbolic_rv_shapes([marginalized_rv.shape])[0]) + + new_mu = graph_replace(mu_d, {marginalized_rv: mu_m}) + new_sigma = sqrt(sigma_d**2 + sigma_m**2) + return logp(Normal.dist(mu=new_mu, sigma=new_sigma), value) + + +@marginalized_conditional.register(NormalNormalMarginalRV) +def normal_normal_conditional(op, node): + fgraph = op.fgraph.clone() + marginalized, inner_dependent = fgraph.outputs[:2] + + mu_m, sigma_m = marginalized.owner.op.dist_params(marginalized.owner) + mu_d, sigma_d = inner_dependent.owner.op.dist_params(inner_dependent.owner) + + dep_dummy = inner_dependent.type() + + offset = graph_replace(mu_d, {marginalized: constant(0, dtype=marginalized.type.dtype)}) + + precision_m = 1 / sigma_m**2 + precision_d = 1 / sigma_d**2 + posterior_precision = precision_m + precision_d + posterior_sigma = sqrt(1 / posterior_precision) + posterior_mu = (mu_m * precision_m + (dep_dummy - offset) * precision_d) / posterior_precision + + return Normal.dist(mu=posterior_mu, sigma=posterior_sigma), [dep_dummy] diff --git a/pymc_extras/model/marginal/rewrites.py b/pymc_extras/model/marginal/rewrites.py index 22088349d..e444aae78 100644 --- a/pymc_extras/model/marginal/rewrites.py +++ b/pymc_extras/model/marginal/rewrites.py @@ -1,11 +1,11 @@ -from pymc.distributions import Bernoulli, Categorical, DiscreteUniform +from pymc.distributions import Bernoulli, Categorical, DiscreteUniform, Normal from pymc.model.fgraph import model_free_rv from pymc.pytensorf import collect_default_updates from pytensor.compile import SharedVariable from pytensor.graph import Apply, Op, node_rewriter from pytensor.graph.replace import graph_replace from pytensor.graph.rewriting.db import EquilibriumDB -from pytensor.graph.traversal import graph_inputs +from pytensor.graph.traversal import ancestors, graph_inputs from pymc_extras.distributions.timeseries import DiscreteMarkovChain from pymc_extras.model.marginal.distributions.core import MarginalRV, inline_ofg_outputs @@ -14,6 +14,7 @@ MarginalFiniteDiscreteRV, ) from pymc_extras.model.marginal.distributions.laplace import MarginalLaplaceRV +from pymc_extras.model.marginal.distributions.normal import NormalNormalMarginalRV from pymc_extras.model.marginal.graph_analysis import subgraph_batch_dim_connection @@ -229,6 +230,56 @@ def laplace_marginal(fgraph, node): marginal_rewrites_db.register("laplace_marginal", laplace_marginal, "basic") +@node_rewriter(tracks=[MarginalSubgraph]) +def normal_normal_marginal_rewrite(fgraph, node): + op = node.op + + if op.n_dependent_rvs != 1: + return None + + inputs, outputs = extract_marginal_subgraph(node) + marginalized_rv = outputs[0] + dependent_rv = outputs[1] + + if not ( + isinstance(marginalized_rv.owner.op, Normal) and isinstance(dependent_rv.owner.op, Normal) + ): + return None + + mu_dep, sigma_dep = dependent_rv.owner.op.dist_params(dependent_rv.owner) + + if marginalized_rv in ancestors([sigma_dep]): + return None + + if mu_dep is not marginalized_rv: + match mu_dep.owner_op_and_inputs: + case (_, a, b): + if a is marginalized_rv: + if marginalized_rv in ancestors([b]): + return None + elif b is marginalized_rv: + if marginalized_rv in ancestors([a]): + return None + else: + return None + case _: + return None + + typed_op = NormalNormalMarginalRV( + inputs=inputs, + outputs=outputs, + marginalized_dims=op.marginalized_dims, + ) + + new_outputs = typed_op(*inputs) + if not isinstance(new_outputs, list): + new_outputs = list(new_outputs) + return new_outputs[: len(node.outputs)] + + +marginal_rewrites_db.register("normal_normal_marginal", normal_normal_marginal_rewrite, "basic") + + @node_rewriter(tracks=[MarginalSubgraph]) def unwrap_inner_marginal_rv(fgraph, node): """Unwrap a MarginalRV inside a MarginalSubgraph's subgraph. diff --git a/tests/model/marginal/test_rewrites.py b/tests/model/marginal/test_rewrites.py index c36ae313c..e2961698d 100644 --- a/tests/model/marginal/test_rewrites.py +++ b/tests/model/marginal/test_rewrites.py @@ -8,7 +8,7 @@ from pymc.model.transform.conditioning import remove_value_transforms from pymc_extras import marginalize -from pymc_extras.model.marginal.model import conditional, unmarginalize +from pymc_extras.model.marginal.model import conditional, recover, unmarginalize def compute_conditional_logprob(cond_model, var_name, domain, point): @@ -258,3 +258,110 @@ def test_sample_posterior_predictive_single(self): random_seed=42, ) assert result.posterior_predictive.idx.values.mean() > 0.99 + + def test_normal_normal_logp(self): + """Test that conditional gives correct conjugate posterior logp for Normal-Normal.""" + sigma_prior = 3.0 + offset = 1.5 + sigma_lik = 4.0 + y_obs = 10.0 + + with pm.Model() as m: + mu = pm.Normal("mu", 0, 10) + x = pm.Normal("x", mu=mu, sigma=sigma_prior) + y = pm.Normal("y", mu=x + offset, sigma=sigma_lik, observed=y_obs) + + marginal_m = marginalize(m, "x") + cond_m = conditional(marginal_m) + + assert "x" in [rv.name for rv in cond_m.free_RVs] + + logp_fn = cond_m.compile_logp(vars=[cond_m["x"]]) + mu_val = 1.0 + x_test = 3.0 + + prec_p = 1 / sigma_prior**2 + prec_l = 1 / sigma_lik**2 + post_prec = prec_p + prec_l + post_sigma = np.sqrt(1 / post_prec) + post_mu = (mu_val * prec_p + (y_obs - offset) * prec_l) / post_prec + + expected = scipy.stats.norm.logpdf(x_test, post_mu, post_sigma) + actual = logp_fn({"mu": mu_val, "x": x_test}) + np.testing.assert_allclose(actual, expected) + + +def test_normal_normal(): + with pm.Model() as m: + x = pm.Normal("x", mu=0, sigma=1) + y = pm.Normal("y", mu=x + np.pi - 1, sigma=1.0) + z = pm.Normal("z", mu=y + 2 * np.pi, sigma=np.sqrt(np.e)) + + marginal_m = marginalize(m, m["y"]) + + test_point = {"x": 1, "z": -1} + + np.testing.assert_allclose( + marginal_m.compile_logp([marginal_m["z"]])(test_point), + scipy.stats.norm.logpdf(test_point["z"], np.pi * 3, np.sqrt(1 + np.e)), + ) + + +@pytest.mark.parametrize("mu_expr", ["x + x", "2 * x"], ids=["x+x", "2*x"]) +@pytest.mark.xfail(reason="Affine f(x)=a*x+b not yet supported") +def test_normal_normal_affine(mu_expr): + with pm.Model() as m: + x = pm.Normal("x", mu=1, sigma=2) + y = pm.Normal("y", mu=eval(mu_expr, {"x": m["x"]}), sigma=3) + + marginal_m = marginalize(m, m["x"]) + + # 2x: mu=2, sigma=sqrt(4*4 + 9)=5 + np.testing.assert_allclose( + marginal_m.compile_logp()({"y": 5.0}), + scipy.stats.norm.logpdf(5.0, 2, 5), + ) + + +def test_normal_normal_nonlinear_in_sigma(): + """Marginalized rv in sigma — not valid for closed-form Normal-Normal.""" + with pm.Model() as m: + x = pm.Normal("x", mu=0, sigma=1) + y = pm.Normal("y", mu=0, sigma=x**2 + 1) + + with pytest.raises(NotImplementedError): + marginalize(m, m["x"]) + + +def test_recover_normal_normal_marginal(): + """Test that recover produces correct conjugate posterior samples.""" + sigma_prior = 3.0 + offset = 1.5 + sigma_lik = 4.0 + y_obs = 10.0 + + with pm.Model() as m: + mu = pm.Normal("mu", 0, 10) + x = pm.Normal("x", mu=mu, sigma=sigma_prior) + y = pm.Normal("y", mu=x + offset, sigma=sigma_lik, observed=y_obs) + + marginal_m = marginalize(m, "x") + + prec_prior = 1 / sigma_prior**2 + prec_lik = 1 / sigma_lik**2 + post_prec = prec_prior + prec_lik + expected_sigma = np.sqrt(1 / post_prec) + + # Use constant mu across many draws for statistical precision + mu_val = 1.0 + expected_mu = (mu_val * prec_prior + (y_obs - offset) * prec_lik) / post_prec + + n_draws = 500 + idata = from_dict({"posterior": {"mu": np.full((4, n_draws), mu_val)}}) + + post = recover(idata, model=marginal_m, random_seed=42) + assert "x" in post.posterior + + x_samples = post.posterior.x.values.flatten() + np.testing.assert_allclose(x_samples.mean(), expected_mu, atol=0.1) + np.testing.assert_allclose(x_samples.std(), expected_sigma, atol=0.1) From fbd8539d8c1e9c37e679194702163e92a78b000a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 12 Jun 2026 11:40:07 +0200 Subject: [PATCH 3/3] Revamp docs and add dev section --- README.md | 53 +++-- docs/api/distributions.rst | 30 +++ docs/api/inference.rst | 18 ++ docs/api/marginalization.rst | 26 +++ docs/api/model.rst | 14 ++ docs/api/prior.rst | 46 ++++ docs/api/reparametrization.rst | 14 ++ docs/api/statespace.rst | 15 ++ docs/api/utils.rst | 21 ++ docs/api_reference.rst | 145 +----------- docs/conf.py | 17 ++ docs/developer/extending_marginalization.rst | 231 +++++++++++++++++++ docs/developer/index.rst | 12 + docs/index.rst | 29 ++- 14 files changed, 510 insertions(+), 161 deletions(-) create mode 100644 docs/api/distributions.rst create mode 100644 docs/api/inference.rst create mode 100644 docs/api/marginalization.rst create mode 100644 docs/api/model.rst create mode 100644 docs/api/prior.rst create mode 100644 docs/api/reparametrization.rst create mode 100644 docs/api/statespace.rst create mode 100644 docs/api/utils.rst create mode 100644 docs/developer/extending_marginalization.rst create mode 100644 docs/developer/index.rst diff --git a/README.md b/README.md index a0db62566..446309d89 100644 --- a/README.md +++ b/README.md @@ -10,52 +10,57 @@ alt="Codecov Badge" /> -As PyMC continues to mature and expand its functionality to accommodate more domains of application, we increasingly see cutting-edge methodologies, highly specialized statistical distributions, and complex models appear. -While this adds to the functionality and relevance of the project, it can also introduce instability and impose a burden on testing and quality control. -To reduce the burden on the main `pymc` repository, this `pymc-extras` repository can become the aggregator and testing ground for new additions to PyMC. -This may include unusual probability distributions, advanced model fitting algorithms, innovative yet not fully tested methods, or niche functionality that might not fit in the main PyMC repository, but still may be of interest to users. +PyMC Extras extends [PyMC](https://www.pymc.io) with additional distributions, inference methods, and model transformations. +It is maintained by the PyMC team and hosts functionality that is too specialized for the core library, but useful enough that you shouldn't have to write it yourself. -The `pymc-extras` repository can be understood as the first step in the PyMC development pipeline, where all novel code is introduced until it is obvious that it belongs in the main repository. -We hope that this organization improves the stability and streamlines the testing overhead of the `pymc` repository, while allowing users and developers to test and evaluate cutting-edge methods and not yet fully mature features. +Highlights include: -`pymc-extras` would be designed to mirror the namespaces in `pymc` to make usage and migration as easy as possible. -For example, a `ParabolicFractal` distribution could be used analogously to those in `pymc`: +- Automatic marginalization: exact for finite discrete and conjugate variables, approximate via the Laplace approximation +- Alternative inference methods: Pathfinder, DADVI, INLA, Laplace approximation, and better MAP estimation +- Statespace models: SARIMAX, VARMAX, ETS, and structural time series with Kalman filtering +- Additional distributions such as `DiscreteMarkovChain`, `GeneralizedPoisson`, and `GenExtreme` + +`pymc-extras` mirrors the namespaces in `pymc` to make usage and migration as easy as possible. +For example, distributions are used exactly like those in `pymc`: ```python import pymc as pm import pymc_extras as pmx with pm.Model(): - alpha = pmx.ParabolicFractal('alpha', b=1, c=1) + xi = pm.HalfNormal("xi", 0.2) + pmx.GenExtreme("llik", mu=1, sigma=0.5, xi=xi, observed=data) +``` + +See the [documentation](https://pymc-extras.readthedocs.io/) for the full API reference. + +## Installation + +```bash +pip install pymc-extras +``` - ... +or for the development version: +```bash +pip install git+https://github.com/pymc-devs/pymc-extras.git ``` ## Questions ### What belongs in `pymc-extras`? -- newly-implemented statistical methods, for example step methods or model construction helpers +- statistical methods, for example step methods or model construction helpers - distributions that are tricky to sample from or test -- infrequently-used fitting methods or distributions +- specialized fitting methods or distributions - any code that requires additional optimization before it can be used in practice +Functionality that proves widely useful may graduate to the main `pymc` repository. ### What does not belong in `pymc-extras`? - Case studies - Implementations that cannot be applied generically, for example because they are tied to variables from a toy example +## Contributing -### Should there be more than one add-on repository? - -Since there is a lot of code that we may not want in the main repository, does it make sense to have more than one additional repository? -For example, `pymc-extras` may just include methods that are not fully developed, tested and trusted, while code that is known to work well and has adequate test coverage, but is still too specialized to become part of `pymc` could reside in a `pymc-extras` (or similar) repository. - - -### Unanswered questions & ToDos -This project is still young and many things have not been answered or implemented. -Please get involved! - -* What are guidelines for organizing submodules? - * Proposal: No default imports of WIP/unstable submodules. By importing manually we can avoid breaking the package if a submodule breaks, for example because of an updated dependency. +We welcome contributions! Check out the [contributing guidelines](https://github.com/pymc-devs/pymc-extras/blob/main/CONTRIBUTING.md) to get started. diff --git a/docs/api/distributions.rst b/docs/api/distributions.rst new file mode 100644 index 000000000..a5bda0fb6 --- /dev/null +++ b/docs/api/distributions.rst @@ -0,0 +1,30 @@ +Distributions +============= + +Distributions that are not (or not yet) part of PyMC itself. They behave +like regular PyMC distributions and can be used directly inside a model. + +.. currentmodule:: pymc_extras.distributions +.. autosummary:: + :toctree: ../generated/ + + Chi + Maxwell + DiscreteMarkovChain + GeneralizedPoisson + BetaNegativeBinomial + GenExtreme + R2D2M2CP + Skellam + histogram_approximation + +Transforms +---------- + +Value transforms for constrained sampling. + +.. currentmodule:: pymc_extras.distributions.transforms +.. autosummary:: + :toctree: ../generated/ + + PartialOrder diff --git a/docs/api/inference.rst b/docs/api/inference.rst new file mode 100644 index 000000000..fa233d734 --- /dev/null +++ b/docs/api/inference.rst @@ -0,0 +1,18 @@ +Inference +========= + +Fitting methods beyond ``pm.sample``: optimization-based point estimates +(``find_MAP``), Gaussian approximations (Laplace, INLA), and fast variational +methods (Pathfinder, DADVI). ``fit`` is a single entry point that dispatches +to these by name. + +.. currentmodule:: pymc_extras.inference +.. autosummary:: + :toctree: ../generated/ + + fit + find_MAP + fit_laplace + fit_pathfinder + fit_dadvi + fit_INLA diff --git a/docs/api/marginalization.rst b/docs/api/marginalization.rst new file mode 100644 index 000000000..b69aa6cd4 --- /dev/null +++ b/docs/api/marginalization.rst @@ -0,0 +1,26 @@ +Marginalization +=============== + +Model transformations that integrate variables out of a model, and recover +them afterwards. Marginalizing discrete variables allows sampling with +gradient-based samplers like NUTS; marginalizing conjugate pairs or using the +Laplace approximation reduces the dimensionality of the posterior. + +``marginalize`` returns a model where the requested variables no longer +appear, but the remaining variables keep their original joint distribution +(exactly, or approximately when using the Laplace approximation). +``unmarginalize`` undoes the transformation, and ``conditional`` / +``recover`` reintroduce the marginalized variables conditioned on the +posterior of the remaining ones. + +.. currentmodule:: pymc_extras.marginal +.. autosummary:: + :toctree: ../generated/ + + marginalize + unmarginalize + conditional + recover + +The set of supported marginalizations is extensible; see +:doc:`../developer/extending_marginalization`. diff --git a/docs/api/model.rst b/docs/api/model.rst new file mode 100644 index 000000000..76ba01ae9 --- /dev/null +++ b/docs/api/model.rst @@ -0,0 +1,14 @@ +Model building +============== + +Tools for defining models. ``as_model`` turns a function with PyMC +statements into a reusable model factory, and ``ModelBuilder`` is a base +class for packaging a model behind a scikit-learn-like ``fit``/``predict`` +interface, with saving and loading included. + +.. currentmodule:: pymc_extras +.. autosummary:: + :toctree: ../generated/ + + as_model + model_builder.ModelBuilder diff --git a/docs/api/prior.rst b/docs/api/prior.rst new file mode 100644 index 000000000..ad07efe8a --- /dev/null +++ b/docs/api/prior.rst @@ -0,0 +1,46 @@ +Prior specification +=================== + +A declarative way to define (hierarchical) prior distributions that can be +serialized to and from JSON. Useful when priors are part of a configuration +file rather than hardcoded in a model, as in +`pymc-marketing `_. + +.. currentmodule:: pymc_extras.prior +.. autosummary:: + :toctree: ../generated/ + + Prior + Censored + Scaled + sample_prior + create_dim_handler + handle_dims + register_tensor_transform + VariableFactory + +From a previous model +--------------------- + +Build a prior from the posterior of a previously fitted model, enabling +simple Bayesian updating workflows. + +.. currentmodule:: pymc_extras.utils +.. autosummary:: + :toctree: ../generated/ + + prior.prior_from_idata + +Deserialization +--------------- + +Registry that maps JSON data back to Python objects, used to round-trip +``Prior`` definitions and extensible to arbitrary custom types. + +.. currentmodule:: pymc_extras.deserialize +.. autosummary:: + :toctree: ../generated/ + + deserialize + register_deserialization + Deserializer diff --git a/docs/api/reparametrization.rst b/docs/api/reparametrization.rst new file mode 100644 index 000000000..d27123c0e --- /dev/null +++ b/docs/api/reparametrization.rst @@ -0,0 +1,14 @@ +Reparametrization +================= + +Automatic reparametrization of hierarchical models. VIP (variationally +inferred parametrization) makes the choice between centered and non-centered +parametrizations continuous and learns the best setting per variable, instead +of leaving it as a manual, all-or-nothing decision. + +.. currentmodule:: pymc_extras.model.transforms +.. autosummary:: + :toctree: ../generated/ + + autoreparam.vip_reparametrize + autoreparam.VIP diff --git a/docs/api/statespace.rst b/docs/api/statespace.rst new file mode 100644 index 000000000..923905f65 --- /dev/null +++ b/docs/api/statespace.rst @@ -0,0 +1,15 @@ +Statespace models +================= + +Linear Gaussian statespace models with Kalman filtering and smoothing: +classical time series models (SARIMAX, VARMAX, ETS) and structural models +built from interpretable components (trend, seasonality, cycles, +autoregressive errors). + +.. automodule:: pymc_extras.statespace +.. toctree:: + :maxdepth: 1 + + ../statespace/core + ../statespace/filters + ../statespace/models diff --git a/docs/api/utils.rst b/docs/api/utils.rst new file mode 100644 index 000000000..ec6175f49 --- /dev/null +++ b/docs/api/utils.rst @@ -0,0 +1,21 @@ +Utilities +========= + +Printing +-------- + +.. currentmodule:: pymc_extras.printing +.. autosummary:: + :toctree: ../generated/ + + model_table + +Miscellaneous +------------- + +.. currentmodule:: pymc_extras.utils +.. autosummary:: + :toctree: ../generated/ + + spline.bspline_interpolation + model_equivalence.equivalent_models diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 17ce77e04..dc07cdb7e 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -1,137 +1,14 @@ API Reference -*************** +************* -Model -===== - -This reference provides detailed documentation for all modules, classes, and -methods in the current release of PyMC experimental. - -.. currentmodule:: pymc_extras -.. autosummary:: - :toctree: generated/ - - as_model - marginalize - model_builder.ModelBuilder - -Marginalization -=============== - -.. currentmodule:: pymc_extras.marginal -.. autosummary:: - :toctree: generated/ - - marginalize - conditional - unmarginalize - recover - -Inference -========= - -.. currentmodule:: pymc_extras.inference -.. autosummary:: - :toctree: generated/ - - find_MAP - fit - fit_laplace - fit_pathfinder - - -Distributions -============= - -.. currentmodule:: pymc_extras.distributions -.. autosummary:: - :toctree: generated/ - - Chi - Maxwell - DiscreteMarkovChain - GeneralizedPoisson - BetaNegativeBinomial - GenExtreme - R2D2M2CP - Skellam - histogram_approximation - -Prior -===== - -.. currentmodule:: pymc_extras.prior -.. autosummary:: - :toctree: generated/ - - create_dim_handler - handle_dims - Prior - register_tensor_transform - VariableFactory - sample_prior - Censored - Scaled - -Deserialize -=========== - -.. currentmodule:: pymc_extras.deserialize -.. autosummary:: - :toctree: generated/ - - deserialize - register_deserialization - Deserializer - - -Transforms -========== - -.. currentmodule:: pymc_extras.distributions.transforms -.. autosummary:: - :toctree: generated/ - - PartialOrder - - -Utils -===== - -.. currentmodule:: pymc_extras.utils -.. autosummary:: - :toctree: generated/ - - spline.bspline_interpolation - prior.prior_from_idata - model_equivalence.equivalent_models - - -Statespace Models -================= -.. automodule:: pymc_extras.statespace .. toctree:: - :maxdepth: 1 - - statespace/core - statespace/filters - statespace/models - - -Model Transforms -================ -.. automodule:: pymc_extras.model.transforms -.. autosummary:: - :toctree: generated/ - - autoreparam.vip_reparametrize - autoreparam.VIP - - -Printing -======== -.. currentmodule:: pymc_extras.printing -.. autosummary:: - :toctree: generated/ - - model_table + :maxdepth: 2 + + api/model + api/marginalization + api/reparametrization + api/inference + api/distributions + api/statespace + api/prior + api/utils diff --git a/docs/conf.py b/docs/conf.py index ec966c103..9c21eaf94 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -56,6 +56,7 @@ "sphinx.ext.viewcode", "sphinx.ext.napoleon", "sphinx.ext.mathjax", + "sphinx.ext.intersphinx", "nbsphinx", "matplotlib.sphinxext.plot_directive", ] @@ -214,3 +215,19 @@ # -- Extension configuration ------------------------------------------------- # https://svn.python.org/projects/external/Jinja-1.1/docs/build/designerdoc.html + +# Docstring plots use the "arviz-darkgrid" style, registered on arviz import. +# arviz >= 1.0 renamed its styles, so alias it when missing. +plot_pre_code = """ +import arviz +import matplotlib.style +if "arviz-darkgrid" not in matplotlib.style.available: + matplotlib.style.library["arviz-darkgrid"] = matplotlib.style.library["arviz-variat"] + matplotlib.style.available.append("arviz-darkgrid") +""" + +intersphinx_mapping = { + "pymc": ("https://www.pymc.io/projects/docs/en/stable/", None), + "pytensor": ("https://pytensor.readthedocs.io/en/latest/", None), + "arviz": ("https://python.arviz.org/en/stable/", None), +} diff --git a/docs/developer/extending_marginalization.rst b/docs/developer/extending_marginalization.rst new file mode 100644 index 000000000..e41e4d89d --- /dev/null +++ b/docs/developer/extending_marginalization.rst @@ -0,0 +1,231 @@ +Extending marginalization +========================= + +:func:`~pymc_extras.marginal.marginalize` is built on a small set of +composable pieces, so new kinds of marginalization (a new conjugate pair, a +new approximation) can be added without touching the core machinery. This +page explains how the pipeline works and what you need to implement. + +All code lives under ``pymc_extras/model/marginal/``. + +How marginalization works +------------------------- + +``marginalize`` operates on the model's +:class:`~pytensor.graph.fg.FunctionGraph` representation (obtained via :func:`pymc.model.fgraph.fgraph_from_model`) in +two stages: + +1. **Marking.** For each variable to marginalize, the subgraph connecting it + to its dependent RVs is wrapped behind a ``MarginalSubgraph`` marker + ``Op`` (in ``rewrites.py``). The marker delimits the variable's Markov + blanket: its children (the dependent RVs) are the subgraph outputs, and + its parents together with the children's other parents form the boundary + inputs. Given the boundary, the marginalized variable is conditionally + independent of the rest of the model, so rewrites can reason about the + marker locally. The marker itself is type-agnostic — it knows nothing + about distributions. + +2. **Resolution.** The ``marginal_rewrites_db`` (a pytensor + ``EquilibriumDB``) is run on the graph. + Each registered rewrite inspects ``MarginalSubgraph`` nodes and, when it + recognizes a pattern it can handle (e.g. a finite discrete variable, or a + Normal whose dependent is also Normal), replaces the marker with a typed + ``MarginalRV``. If no rewrite claims a marker, ``marginalize`` raises + ``NotImplementedError``. + +A ``MarginalRV`` (in ``distributions/core.py``) is an +:class:`~pytensor.compile.builders.OpFromGraph` that is also a PyMC +``MeasurableOp``. Its inner graph is the original *generative* subgraph — +it still draws the marginalized variable and the dependents given it, so +forward sampling (``pm.sample_prior_predictive``) works unchanged. What makes +it "marginal" is its logp implementation: + +* Each ``MarginalRV`` subclass registers a ``_logprob`` dispatch that returns + the **marginal** logp of the dependent values, with the marginalized + variable integrated/summed out. +* Optionally, it also registers ``marginalized_conditional``, which builds + ``p(marginalized | dependents)``. This is what powers + :func:`~pymc_extras.marginal.conditional` and + :func:`~pymc_extras.marginal.recover`. + +:func:`~pymc_extras.marginal.unmarginalize` is fully generic: it just inlines +the ``OpFromGraph`` and restores the marginalized variable as a free RV, so +new marginalizations get it for free. + +Adding a new marginalization +---------------------------- + +The example below adds Gamma-Poisson marginalization for the simplest +possible case — a *scalar* Gamma that is *directly* the rate of a single +Poisson: + +.. math:: + + z \sim \text{Gamma}(\alpha, \beta), \quad y \sim \text{Poisson}(z) + +The marginal of :math:`y` is +:math:`\text{NegativeBinomial}(\alpha, \beta / (\beta + 1))` and the +conditional is the conjugate posterior +:math:`z \mid y \sim \text{Gamma}(\alpha + y, \beta + 1)`. + +1. **Subclass** ``MarginalRV``. The inner graph outputs are laid out as + ``[marginalized_rv, *dependent_rvs, *rng_updates]``: + + .. code-block:: python + + from pymc_extras.model.marginal.distributions.core import MarginalRV + + + class GammaPoissonMarginalRV(MarginalRV): + """Marginalized Gamma-Poisson pair.""" + + def __init__(self, *args, marginalized_dims, **kwargs): + self.marginalized_dims = marginalized_dims + self.n_dependent_rvs = 1 + super().__init__(*args, **kwargs) + +2. **Write the rewrite** that recognizes the pattern. It tracks + ``MarginalSubgraph`` and uses ``extract_marginal_subgraph`` to get the + subgraph's inputs/outputs (RNG updates included). Make the pattern match + as restrictive as needed to keep the implementation simple — here we + require a scalar Gamma used directly as the Poisson rate, which sidesteps + transformed parameters and batch-dimension bookkeeping entirely. Return + ``None`` whenever the pattern does not apply, so other rewrites get a + chance: + + .. code-block:: python + + from pymc.distributions import Gamma, Poisson + from pytensor.graph import node_rewriter + + from pymc_extras.model.marginal.rewrites import ( + MarginalSubgraph, + extract_marginal_subgraph, + marginal_rewrites_db, + ) + + + @node_rewriter(tracks=[MarginalSubgraph]) + def gamma_poisson_marginal(fgraph, node): + if node.op.n_dependent_rvs != 1: + return None + + inputs, outputs = extract_marginal_subgraph(node) + marginalized_rv, dependent_rv = outputs[:2] + + if not ( + isinstance(marginalized_rv.owner.op, Gamma) + and isinstance(dependent_rv.owner.op, Poisson) + and marginalized_rv.type.ndim == 0 + ): + return None + + [poisson_mu] = dependent_rv.owner.op.dist_params(dependent_rv.owner) + if poisson_mu is not marginalized_rv: + return None + + typed_op = GammaPoissonMarginalRV( + inputs=inputs, + outputs=outputs, + marginalized_dims=node.op.marginalized_dims, + ) + new_outputs = typed_op(*inputs) + return list(new_outputs)[: len(node.outputs)] + + + marginal_rewrites_db.register( + "gamma_poisson_marginal", gamma_poisson_marginal, "basic" + ) + + Because the database is an ``EquilibriumDB``, rewrites run in no + particular order and repeatedly until the graph stabilizes. Be + conservative in what you match, and decline (``return None``) rather than + raise when unsure — raising is reserved for patterns that are recognizably + yours but unsupported (see ``finite_discrete_marginal`` for an example). + +3. **Register the marginal logp.** Use ``inline_ofg_outputs`` to recover the + inner generative graph expressed over the node's actual inputs, extract + the parameters, and return the logp of the dependent values with the + marginalized variable integrated out. Note that ``dist_params`` returns + the *backend* parametrization — for Gamma that is ``(alpha, scale)``, not + ``(alpha, beta)``: + + .. code-block:: python + + from pymc import NegativeBinomial + from pymc.logprob.abstract import _logprob + from pymc.logprob.basic import logp + + from pymc_extras.model.marginal.distributions.core import inline_ofg_outputs + + + @_logprob.register(GammaPoissonMarginalRV) + def gamma_poisson_marginal_logp(op, values, *inputs, **kwargs): + [value] = values + marginalized_rv, _ = inline_ofg_outputs(op, inputs)[:2] + alpha, scale = marginalized_rv.owner.op.dist_params(marginalized_rv.owner) + beta = 1 / scale + return logp(NegativeBinomial.dist(n=alpha, p=beta / (beta + 1)), value) + +4. **(Optional) Register the conditional** to support + :func:`~pymc_extras.marginal.conditional` and + :func:`~pymc_extras.marginal.recover`. It returns a + ``(sample_graph, dep_dummies)`` pair: a random variable distributed as + ``p(marginalized | dependents)``, and placeholder tensors standing in for + the dependent values (the caller substitutes the actual model variables or + observed data; see the docstring of ``marginalized_conditional`` in + ``distributions/core.py`` for the full contract): + + .. code-block:: python + + from pymc_extras.model.marginal.distributions.core import marginalized_conditional + + + @marginalized_conditional.register(GammaPoissonMarginalRV) + def gamma_poisson_conditional(op, node): + fgraph = op.fgraph.clone() + marginalized, dependent = fgraph.outputs[:2] + alpha, scale = marginalized.owner.op.dist_params(marginalized.owner) + beta = 1 / scale + + dep_dummy = dependent.type() + return Gamma.dist(alpha + dep_dummy, beta + 1), [dep_dummy] + + Without this registration everything else works; + ``conditional``/``recover`` will raise for your variable. + +That's the whole extension: + +.. code-block:: python + + import pymc as pm + from pymc_extras.marginal import conditional, marginalize + + with pm.Model() as m: + z = pm.Gamma("z", 2.0, 3.0) + y = pm.Poisson("y", mu=z, observed=4) + + marg_m = marginalize(m, [z]) # y is now NegativeBinomial under the hood + cond_m = conditional(marg_m) # z is back, as Gamma(2 + 4, 3 + 1) + +For a real-world template handling broadcasting and parameter transformations +see ``NormalNormalMarginalRV`` (``distributions/normal.py``). + +Other things you get (or must check) for free +--------------------------------------------- + +* **Support points** for initialization are derived generically from the + inner graph (``_support_point_marginal_rv`` in ``distributions/core.py``). +* **Nested marginalization** (marginalizing a variable that depends on an + already-marginalized one) is handled by ``unwrap_inner_marginal_rv``, + which inlines the existing ``MarginalRV`` and re-marks both subgraphs. + Your rewrite only ever sees plain ``MarginalSubgraph`` markers. +* **Batched dependencies**: if your marginalization needs to reason about + how batch dimensions of the marginalized variable map onto the dependents + (as the finite discrete case does), use + ``subgraph_batch_dim_connection`` from ``graph_analysis.py``. + +Tests live in ``tests/model/marginal/``; ``test_distributions.py`` and +``test_rewrites.py`` show the expected coverage for a new marginalization: +the marginal logp against a reference, forward sampling, and round-trips +through ``unmarginalize``/``recover``. diff --git a/docs/developer/index.rst b/docs/developer/index.rst new file mode 100644 index 000000000..16886d82e --- /dev/null +++ b/docs/developer/index.rst @@ -0,0 +1,12 @@ +Developer Guide +=============== + +Documentation of PyMC Extras internals, aimed at contributors rather than +users. For the general contribution workflow (setup, tests, pull requests), +see the +`contributing guidelines `_. + +.. toctree:: + :maxdepth: 1 + + extending_marginalization diff --git a/docs/index.rst b/docs/index.rst index 7bf8edbbd..0fbf3a848 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,8 +12,26 @@ PyMC Extras :target: https://codecov.io/gh/pymc-devs/pymc-extras - -Where we grow the next batch of cool PyMC features +PyMC Extras extends `PyMC `_ with additional +distributions, inference methods, and model transformations. It is maintained +by the PyMC team and hosts functionality that is too specialized for the core +library, but useful enough that you shouldn't have to write it yourself. + +What's inside +============= + +* :doc:`Automatic marginalization `: exact for finite + discrete and conjugate variables, approximate via the Laplace approximation. +* :doc:`Alternative inference methods `: Pathfinder, DADVI, + INLA, Laplace approximation, and better MAP estimation. +* :doc:`Statespace models `: SARIMAX, VARMAX, ETS, and + structural time series with Kalman filtering. +* :doc:`Additional distributions ` such as + ``DiscreteMarkovChain``, ``GeneralizedPoisson``, and ``GenExtreme``. +* :doc:`Model building tools ` like the ``as_model`` decorator and + the ``ModelBuilder`` base class. + +See the full :doc:`api_reference` for everything else. Installation ============ @@ -32,7 +50,11 @@ For the development version, you can install directly from GitHub: Contributing ============ -We welcome contributions from interested individuals or groups! For information about contributing to PyMC Extras check out our instructions, policies, and guidelines `here `_. +We welcome contributions from interested individuals or groups! For information +about contributing to PyMC Extras check out our instructions, policies, and +guidelines `here `_. +If you want to extend the internals (e.g. add a new marginalization), start +with the :doc:`developer/index`. Contributors ============ @@ -42,3 +64,4 @@ See the `GitHub contributor page