|
3 | 3 | from collections.abc import Sequence |
4 | 4 |
|
5 | 5 | import numpy as np |
| 6 | +import pytensor |
6 | 7 | import pytensor.tensor as pt |
7 | 8 |
|
8 | 9 | from pymc.distributions import Bernoulli, Categorical, DiscreteUniform |
9 | 10 | from pymc.distributions.distribution import _support_point, support_point |
| 11 | +from pymc.distributions.multivariate import _logdet_from_cholesky |
10 | 12 | from pymc.logprob.abstract import MeasurableOp, _logprob |
11 | 13 | from pymc.logprob.basic import conditional_logp, logp |
12 | 14 | from pymc.pytensorf import constant_fold |
|
17 | 19 | from pytensor.graph.replace import clone_replace, graph_replace |
18 | 20 | from pytensor.scan import map as scan_map |
19 | 21 | from pytensor.scan import scan |
20 | | -from pytensor.tensor import TensorVariable |
| 22 | +from pytensor.tensor import TensorLike, TensorVariable |
| 23 | +from pytensor.tensor.optimize import minimize |
21 | 24 | from pytensor.tensor.random.type import RandomType |
22 | 25 |
|
23 | 26 | from pymc_extras.distributions import DiscreteMarkovChain |
@@ -134,6 +137,24 @@ class MarginalDiscreteMarkovChainRV(MarginalRV): |
134 | 137 | """Base class for Marginalized Discrete Markov Chain RVs""" |
135 | 138 |
|
136 | 139 |
|
| 140 | +class MarginalLaplaceRV(MarginalRV): |
| 141 | + """Base class for Marginalized Laplace-Approximated RVs. |
| 142 | +
|
| 143 | + Estimates log likelihood using Laplace approximations. |
| 144 | + """ |
| 145 | + |
| 146 | + def __init__( |
| 147 | + self, |
| 148 | + *args, |
| 149 | + minimizer_seed: int, |
| 150 | + minimizer_kwargs: dict = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}, |
| 151 | + **kwargs, |
| 152 | + ) -> None: |
| 153 | + self.minimizer_seed = minimizer_seed |
| 154 | + self.minimizer_kwargs = minimizer_kwargs |
| 155 | + super().__init__(*args, **kwargs) |
| 156 | + |
| 157 | + |
137 | 158 | def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]: |
138 | 159 | op = rv.owner.op |
139 | 160 | dist_params = rv.owner.op.dist_params(rv.owner) |
@@ -377,3 +398,137 @@ def step_alpha(logp_emission, log_alpha, log_P): |
377 | 398 | warn_non_separable_logp(values) |
378 | 399 | dummy_logps = (DUMMY_ZERO,) * (len(values) - 1) |
379 | 400 | return joint_logp, *dummy_logps |
| 401 | + |
| 402 | + |
| 403 | +def _precision_mv_normal_logp(value: TensorLike, mean: TensorLike, tau: TensorLike): |
| 404 | + """ |
| 405 | + Compute the log likelihood of a multivariate normal distribution in precision form. May be phased out - see https://github.com/pymc-devs/pymc/pull/7895 |
| 406 | +
|
| 407 | + Parameters |
| 408 | + ---------- |
| 409 | + value: TensorLike |
| 410 | + Query point to compute the log prob at. |
| 411 | + mean: TensorLike |
| 412 | + Mean vector of the Gaussian, |
| 413 | + tau: TensorLike |
| 414 | + Precision matrix of the Gaussian (i.e. cov = inv(tau)) |
| 415 | +
|
| 416 | + Returns |
| 417 | + ------- |
| 418 | + logp: TensorLike |
| 419 | + Log likelihood at value. |
| 420 | + posdef: TensorLike |
| 421 | + Boolean indicating whether the precision matrix is positive definite. |
| 422 | + """ |
| 423 | + k = value.shape[-1].astype("floatX") |
| 424 | + |
| 425 | + delta = value - mean |
| 426 | + quadratic_form = delta.T @ tau @ delta |
| 427 | + logdet, posdef = _logdet_from_cholesky(pt.linalg.cholesky(tau, lower=True)) |
| 428 | + logp = -0.5 * (k * pt.log(2 * np.pi) + quadratic_form) + logdet |
| 429 | + |
| 430 | + return logp, posdef |
| 431 | + |
| 432 | + |
| 433 | +def get_laplace_approx( |
| 434 | + log_likelihood: TensorVariable, |
| 435 | + logp_objective: TensorVariable, |
| 436 | + x: TensorVariable, |
| 437 | + x0_init: TensorLike, |
| 438 | + Q: TensorLike, |
| 439 | + minimizer_kwargs: dict = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}, |
| 440 | +): |
| 441 | + """ |
| 442 | + Compute the laplace approximation logp_G(x | y, params) of some variable x. |
| 443 | +
|
| 444 | + Parameters |
| 445 | + ---------- |
| 446 | + log_likelihood: TensorVariable |
| 447 | + Model likelihood logp(y | x, params). |
| 448 | + logp_objective: TensorVariable |
| 449 | + Obective log likelihood to maximize, logp(x | y, params) (up to some constant in x). |
| 450 | + x: TensorVariable |
| 451 | + Variable to be laplace approximated. |
| 452 | + x0_init: TensorLike |
| 453 | + Initial guess for minimization. |
| 454 | + Q: TensorLike |
| 455 | + Precision matrix of x. |
| 456 | + minimizer_kwargs: |
| 457 | + Kwargs to pass to pytensor.optimize.minimize. |
| 458 | +
|
| 459 | + Returns |
| 460 | + ------- |
| 461 | + x0: TensorVariable |
| 462 | + x*, the maximizer of logp(x | y, params) in x. |
| 463 | + log_laplace_approx: TensorVariable |
| 464 | + Laplace approximation of logp(x | y, params) evaluated at x. |
| 465 | + """ |
| 466 | + # Maximize log(p(x | y, params)) wrt x to find mode x0 |
| 467 | + # This step is currently bottlenecking the logp calculation. |
| 468 | + x0, _ = minimize( |
| 469 | + objective=-logp_objective, # logp(x | y, params) = logp(y | x, params) + logp(x | params) + const (const omitted during minimization) |
| 470 | + x=x, |
| 471 | + use_vectorized_jac=True, |
| 472 | + **minimizer_kwargs, |
| 473 | + ) |
| 474 | + |
| 475 | + # Set minimizer initialisation to be random |
| 476 | + x0 = pytensor.graph.replace.graph_replace(x0, {x: x0_init}) |
| 477 | + |
| 478 | + # This step is also expensive (but not as much as minimize). Could be made more efficient by recycling hessian from the minimizer step, however that requires a bespoke algorithm described in Rasmussen & Williams |
| 479 | + # since the general optimisation scheme maximises logp(x | y, params) rather than logp(y | x, params), and thus the hessian that comes out of methods |
| 480 | + # like L-BFGS-B is in fact not the hessian of logp(y | x, params) |
| 481 | + # TODO: Use vectorized hessian? |
| 482 | + hess = pytensor.gradient.hessian(log_likelihood, x) |
| 483 | + |
| 484 | + # Evaluate logp of Laplace approx of logp(x | y, params) at some point x |
| 485 | + tau = Q - hess |
| 486 | + mu = x0 |
| 487 | + log_laplace_approx, _ = _precision_mv_normal_logp(x, mu, tau) |
| 488 | + |
| 489 | + return x0, log_laplace_approx |
| 490 | + |
| 491 | + |
| 492 | +@_logprob.register(MarginalLaplaceRV) |
| 493 | +def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs_and_Q, **kwargs): |
| 494 | + # Get Q and remove it from the graph (stored as a dummy input) |
| 495 | + *inputs, Q = inputs_and_Q |
| 496 | + |
| 497 | + # Clone the inner RV graph of the Marginalized RV |
| 498 | + x, *inner_rvs = inline_ofg_outputs(op, inputs) |
| 499 | + |
| 500 | + # Obtain the joint_logp graph of the inner RV graph |
| 501 | + inner_rv_values = dict(zip(inner_rvs, values)) |
| 502 | + |
| 503 | + marginalized_vv = x.clone() |
| 504 | + rv_values = inner_rv_values | {x: marginalized_vv} |
| 505 | + logps_dict = conditional_logp(rv_values=rv_values, **kwargs) |
| 506 | + |
| 507 | + # logp(x | params) |
| 508 | + logp_x = logps_dict.pop(marginalized_vv).sum() |
| 509 | + |
| 510 | + # logp(y | x, params) |
| 511 | + logp_y = pt.sum([logp_term.sum() for value, logp_term in logps_dict.items()]) |
| 512 | + |
| 513 | + # logp_total = logp(y | x, params) + logp(x | params) (i.e. logp(x | y, params) up to a constant in x) |
| 514 | + logp_total = logp_x + logp_y |
| 515 | + |
| 516 | + # Set minimizer initialisation to be random (TODO: Let pymc accept this one, maybe when rng is constant) |
| 517 | + # TODO: Use newer pytensor helper |
| 518 | + d = pt.prod(constant_fold(tuple(x.shape), raise_not_constant=True)) |
| 519 | + x0_init = pt.ones(d) |
| 520 | + |
| 521 | + # Obtain laplace approx for logp(x | y, params) |
| 522 | + x0, log_laplace_approx = get_laplace_approx( |
| 523 | + logp_y, |
| 524 | + logp_total, |
| 525 | + x=marginalized_vv, |
| 526 | + x0_init=x0_init, |
| 527 | + Q=Q, |
| 528 | + minimizer_kwargs=op.minimizer_kwargs, |
| 529 | + ) |
| 530 | + |
| 531 | + # logp(y | params) = logp(y | x, params) + logp(x | params) - logp(x | y, params) |
| 532 | + # TODO: Can we recover the elementwise logp? |
| 533 | + marginal_likelihood = logp_total - log_laplace_approx |
| 534 | + return graph_replace(marginal_likelihood, {marginalized_vv: x0}) |
0 commit comments