Skip to content

Commit ee8cc37

Browse files
Implement basic INLA interface (#533)
1 parent 59f55df commit ee8cc37

12 files changed

Lines changed: 1094 additions & 102 deletions

File tree

conda-envs/environment-test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,5 @@ dependencies:
1919
- pip:
2020
- jax
2121
- blackjax
22+
- pytensor>=3.0.4
23+
- preliz>=0.25

notebooks/INLA Example.ipynb

Lines changed: 634 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from pymc_extras.inference.INLA.inla import fit_INLA
2+
3+
__all__ = ["fit_INLA"]

pymc_extras/inference/INLA/inla.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import warnings
2+
3+
import pymc as pm
4+
5+
from pytensor.tensor import TensorLike, TensorVariable, as_tensor
6+
from xarray import DataTree
7+
8+
from pymc_extras.model.marginal.marginal_model import marginalize
9+
10+
11+
def fit_INLA(
    x: TensorVariable,
    Q: TensorLike,
    minimizer_seed: int = 42,
    model: pm.Model | None = None,
    minimizer_kwargs: dict | None = None,
    return_latent_posteriors: bool = False,
    **sampler_kwargs,
) -> DataTree:
    r"""
    Performs inference over a linear mixed model using Integrated Nested Laplace Approximations (INLA). Assumes a model of the form:

    .. math::

        \theta \rightarrow x \rightarrow y

    Where the prior on the hyperparameters :math:`\pi(\theta)` is arbitrary, the prior on the latent field is Gaussian (and in precision form): :math:`\pi(x) = N(\mu, Q^{-1})` and the latent field is linked to the observables :math:`y` through some linear map.

    As it stands, INLA in PyMC Extras is currently experimental.

    Parameters
    ----------
    x: TensorVariable
        The latent gaussian to marginalize out.
    Q: TensorLike
        Precision matrix of the latent field.
    minimizer_seed: int
        Seed for random initialisation of the minimum point x*.
    model: pm.Model
        PyMC model. If None, the model is taken from the current model context.
    minimizer_kwargs: dict, optional
        Kwargs to pass to pytensor.tensor.optimize.minimize during the optimization step maximizing logp(x | y, params).
        If None, defaults to ``{"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}``.
    return_latent_posteriors: bool
        If True, also return posteriors for the latent Gaussian field (currently unsupported).
    sampler_kwargs:
        Kwargs to pass to pm.sample.

    Returns
    -------
    DataTree
        The inference data containing the results of the INLA algorithm.

    Raises
    ------
    NotImplementedError
        If ``return_latent_posteriors`` is True.

    Examples
    --------
    .. code:: ipython

        In [1]: rng = np.random.default_rng(123)
           ...: n = 10000
           ...: d = 3
           ...: mu_mu = 10 * rng.random(d)
           ...: mu_true = rng.random(d)
           ...: tau = np.identity(d)
           ...: cov = np.linalg.inv(tau)
           ...: y_obs = rng.multivariate_normal(mean=mu_true, cov=cov, size=n)

        In [2]: with pm.Model() as model:
           ...:     mu = pm.MvNormal("mu", mu=mu_mu, tau=tau)
           ...:     x = pm.MvNormal("x", mu=mu, tau=tau)
           ...:     y = pm.MvNormal("y", mu=x, tau=tau, observed=y_obs)
           ...:
           ...:     idata = pmx.fit(
           ...:         method="INLA",
           ...:         x=x,
           ...:         Q=tau,
           ...:         return_latent_posteriors=False,
           ...:     )

        In [3]: posterior_mean_true = (mu_mu + mu_true) / 2
           ...: posterior_mean_inla = idata.posterior.mu.mean(axis=(0, 1)).values
           ...: print(posterior_mean_true)
           ...: print(posterior_mean_inla)

        Out[3]:
        [3.50394522 0.35705804 1.50784662]
        [3.48732847 0.35738072 1.46851421]

    """
    warnings.warn(
        "INLA is currently experimental. Please see the INLA Roadmap for more info: https://github.com/pymc-devs/pymc-extras/issues/340.",
        UserWarning,
    )

    # Fail fast: unmarginalizing the latent field is not implemented, so there is
    # no point building the marginal model before reporting that.
    if return_latent_posteriors:
        raise NotImplementedError(
            "Inference over the latent field with INLA is currently unsupported. Set return_latent_posteriors to False"
        )

    # Build the default per call so no dict object is shared between invocations.
    if minimizer_kwargs is None:
        minimizer_kwargs = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}

    model = pm.modelcontext(model)

    # Marginalize out the latent field
    marginalize_kwargs = {
        "Q": as_tensor(Q),
        "minimizer_seed": minimizer_seed,
        "minimizer_kwargs": minimizer_kwargs,
    }
    marginal_model = marginalize(model, x, use_laplace=True, **marginalize_kwargs)

    # Sample over the hyperparameters
    return pm.sample(model=marginal_model, **sampler_kwargs)

pymc_extras/inference/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from pymc_extras.inference.dadvi.dadvi import fit_dadvi
1616
from pymc_extras.inference.fit import fit
17+
from pymc_extras.inference.INLA.inla import fit_INLA
1718
from pymc_extras.inference.laplace_approx.find_map import find_MAP
1819
from pymc_extras.inference.laplace_approx.laplace import fit_laplace
1920
from pymc_extras.inference.pathfinder.pathfinder import fit_blackjax_pathfinder, fit_pathfinder
@@ -25,4 +26,5 @@
2526
"fit_laplace",
2627
"fit_pathfinder",
2728
"fit_dadvi",
29+
"fit_INLA",
2830
]

pymc_extras/inference/fit.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,22 @@ def fit(method: str, **kwargs) -> DataTree:
3636

3737
return fit_pathfinder(**kwargs)
3838

39-
if method == "laplace":
40-
from pymc_extras.inference import fit_laplace
39+
elif method == "laplace":
40+
from pymc_extras.inference.laplace_approx import fit_laplace
4141

4242
return fit_laplace(**kwargs)
4343

44-
if method == "dadvi":
44+
elif method == "INLA":
45+
from pymc_extras.inference.INLA import fit_INLA
46+
47+
return fit_INLA(**kwargs)
48+
49+
elif method == "dadvi":
4550
from pymc_extras.inference import fit_dadvi
4651

4752
return fit_dadvi(**kwargs)
53+
54+
else:
55+
raise ValueError(
56+
f"method '{method}' not supported. Use one of 'pathfinder', 'laplace' or 'INLA'."
57+
)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from pymc_extras.inference.laplace_approx.laplace import fit_laplace
2+
3+
__all__ = ["fit_laplace"]

pymc_extras/model/marginal/distributions.py

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from collections.abc import Sequence
44

55
import numpy as np
6+
import pytensor
67
import pytensor.tensor as pt
78

89
from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
910
from pymc.distributions.distribution import _support_point, support_point
11+
from pymc.distributions.multivariate import _logdet_from_cholesky
1012
from pymc.logprob.abstract import MeasurableOp, _logprob
1113
from pymc.logprob.basic import conditional_logp, logp
1214
from pymc.pytensorf import constant_fold
@@ -17,7 +19,8 @@
1719
from pytensor.graph.replace import clone_replace, graph_replace
1820
from pytensor.scan import map as scan_map
1921
from pytensor.scan import scan
20-
from pytensor.tensor import TensorVariable
22+
from pytensor.tensor import TensorLike, TensorVariable
23+
from pytensor.tensor.optimize import minimize
2124
from pytensor.tensor.random.type import RandomType
2225

2326
from pymc_extras.distributions import DiscreteMarkovChain
@@ -134,6 +137,24 @@ class MarginalDiscreteMarkovChainRV(MarginalRV):
134137
"""Base class for Marginalized Discrete Markov Chain RVs"""
135138

136139

140+
class MarginalLaplaceRV(MarginalRV):
    """Base class for Marginalized Laplace-Approximated RVs.

    Estimates log likelihood using Laplace approximations.

    Parameters
    ----------
    minimizer_seed: int
        Seed stored on the op as ``self.minimizer_seed``.
    minimizer_kwargs: dict, optional
        Kwargs stored on the op as ``self.minimizer_kwargs`` for the minimization
        step. If None, defaults to
        ``{"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}``.
    *args, **kwargs
        Forwarded unchanged to ``MarginalRV.__init__``.
    """

    def __init__(
        self,
        *args,
        minimizer_seed: int,
        minimizer_kwargs: dict | None = None,
        **kwargs,
    ) -> None:
        # Build the default per instance: a dict literal in the signature would be
        # a single object shared (and mutable) across every op created.
        if minimizer_kwargs is None:
            minimizer_kwargs = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}
        self.minimizer_seed = minimizer_seed
        self.minimizer_kwargs = minimizer_kwargs
        super().__init__(*args, **kwargs)
156+
157+
137158
def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
138159
op = rv.owner.op
139160
dist_params = rv.owner.op.dist_params(rv.owner)
@@ -377,3 +398,137 @@ def step_alpha(logp_emission, log_alpha, log_P):
377398
warn_non_separable_logp(values)
378399
dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
379400
return joint_logp, *dummy_logps
401+
402+
403+
def _precision_mv_normal_logp(value: TensorLike, mean: TensorLike, tau: TensorLike):
    """
    Log density of a multivariate normal parameterised by its precision matrix.

    May be phased out - see https://github.com/pymc-devs/pymc/pull/7895

    Parameters
    ----------
    value: TensorLike
        Point at which the log density is evaluated.
    mean: TensorLike
        Mean vector of the Gaussian.
    tau: TensorLike
        Precision matrix of the Gaussian (i.e. cov = inv(tau)).

    Returns
    -------
    logp: TensorLike
        Log density evaluated at ``value``.
    posdef: TensorLike
        Positive-definiteness flag for the precision matrix, as reported by
        ``_logdet_from_cholesky``.
    """
    # Log-determinant term and posdef flag come from the lower Cholesky factor of tau
    chol_tau = pt.linalg.cholesky(tau, lower=True)
    logdet, posdef = _logdet_from_cholesky(chol_tau)

    # Quadratic form (value - mean)^T tau (value - mean)
    residual = value - mean
    mahalanobis = residual.T @ tau @ residual

    # Size of the trailing axis of value, cast to floatX for the normalising constant
    n_dims = value.shape[-1].astype("floatX")
    normaliser = n_dims * pt.log(2 * np.pi)

    log_density = -0.5 * (normaliser + mahalanobis) + logdet
    return log_density, posdef
431+
432+
433+
def get_laplace_approx(
    log_likelihood: TensorVariable,
    logp_objective: TensorVariable,
    x: TensorVariable,
    x0_init: TensorLike,
    Q: TensorLike,
    minimizer_kwargs: dict | None = None,
):
    """
    Compute the laplace approximation logp_G(x | y, params) of some variable x.

    Parameters
    ----------
    log_likelihood: TensorVariable
        Model likelihood logp(y | x, params).
    logp_objective: TensorVariable
        Objective log likelihood to maximize, logp(x | y, params) (up to some constant in x).
    x: TensorVariable
        Variable to be laplace approximated.
    x0_init: TensorLike
        Initial guess for minimization.
    Q: TensorLike
        Precision matrix of x.
    minimizer_kwargs: dict, optional
        Kwargs to pass to pytensor.tensor.optimize.minimize. If None, defaults to
        ``{"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}``.

    Returns
    -------
    x0: TensorVariable
        x*, the maximizer of logp(x | y, params) in x.
    log_laplace_approx: TensorVariable
        Laplace approximation of logp(x | y, params) evaluated at x.
    """
    # Build the default per call so no dict object is shared between invocations.
    if minimizer_kwargs is None:
        minimizer_kwargs = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}

    # Maximize log(p(x | y, params)) wrt x to find mode x0
    # This step is currently bottlenecking the logp calculation.
    x0, _ = minimize(
        objective=-logp_objective,  # logp(x | y, params) = logp(y | x, params) + logp(x | params) + const (const omitted during minimization)
        x=x,
        use_vectorized_jac=True,
        **minimizer_kwargs,
    )

    # Start the minimizer from the caller-supplied initial guess by substituting it for x
    x0 = graph_replace(x0, {x: x0_init})

    # This step is also expensive (but not as much as minimize). Could be made more efficient by recycling hessian from the minimizer step, however that requires a bespoke algorithm described in Rasmussen & Williams
    # since the general optimisation scheme maximises logp(x | y, params) rather than logp(y | x, params), and thus the hessian that comes out of methods
    # like L-BFGS-B is in fact not the hessian of logp(y | x, params)
    # TODO: Use vectorized hessian?
    hess = pytensor.gradient.hessian(log_likelihood, x)

    # Evaluate logp of Laplace approx of logp(x | y, params) at some point x:
    # a Gaussian centred on the mode x0 with precision Q - hess
    tau = Q - hess
    log_laplace_approx, _ = _precision_mv_normal_logp(x, x0, tau)

    return x0, log_laplace_approx
490+
491+
492+
@_logprob.register(MarginalLaplaceRV)
def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs_and_Q, **kwargs):
    """Build the Laplace-approximated marginal logp graph for a MarginalLaplaceRV.

    The last positional input is the precision matrix Q (carried as a dummy input);
    the remaining inputs are those of the op's inner graph. The returned graph is
    logp(y | x, params) + logp(x | params) - logp_G(x | y, params), with the value
    of the marginalized variable replaced by the mode x0 found by the minimizer.
    """
    # Q is stored as the trailing (dummy) input; peel it off
    *inputs, Q = inputs_and_Q

    # Inline the op's inner graph: the first output is the marginalized RV,
    # the rest are the RVs that depend on it
    latent_rv, *dependent_rvs = inline_ofg_outputs(op, inputs)

    # Pair each dependent RV with its value and give the latent RV a fresh value variable
    latent_value = latent_rv.clone()
    rv_to_value = dict(zip(dependent_rvs, values)) | {latent_rv: latent_value}
    logp_terms = conditional_logp(rv_values=rv_to_value, **kwargs)

    # logp(x | params)
    logp_x = logp_terms.pop(latent_value).sum()

    # logp(y | x, params): everything left once the latent term has been popped
    logp_y = pt.sum([term.sum() for term in logp_terms.values()])

    # logp(y | x, params) + logp(x | params), i.e. logp(x | y, params) up to a constant in x
    logp_total = logp_x + logp_y

    # Initial guess for the minimizer: a vector of ones whose length is the (constant-folded)
    # number of elements of x. op.minimizer_seed is not consulted in this function.
    # TODO: Random initialisation (let pymc accept this one, maybe when rng is constant)
    # TODO: Use newer pytensor helper
    n_elements = pt.prod(constant_fold(tuple(latent_rv.shape), raise_not_constant=True))
    x0_init = pt.ones(n_elements)

    # Laplace approximation of logp(x | y, params)
    x0, log_laplace_approx = get_laplace_approx(
        logp_y,
        logp_total,
        x=latent_value,
        x0_init=x0_init,
        Q=Q,
        minimizer_kwargs=op.minimizer_kwargs,
    )

    # logp(y | params) = logp(y | x, params) + logp(x | params) - logp(x | y, params)
    # TODO: Can we recover the elementwise logp?
    marginal_likelihood = logp_total - log_laplace_approx

    # Evaluate the whole expression at the mode
    return graph_replace(marginal_likelihood, {latent_value: x0})

0 commit comments

Comments
 (0)