Move logdensity_fn to build_kernel and pass logdensity_params at call time #881

Description

@junpenglao

Background

Currently every MCMC inner kernel accepts logdensity_fn at call time:

kernel(rng_key, state, logdensity_fn, step_size, ...)

The original motivation (see #495) was SMC tempering: at each SMC step the target density changes (λ · loglikelihood + logprior), so a new Python closure is constructed and passed to the kernel on every step. Keeping logdensity_fn as a call-time argument was the only way to support this without rebuilding the kernel.

The problem

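Python callables are static from JAX's perspective: a function that enters a jitted computation as a static argument is part of the compilation cache key, and plain functions hash and compare by identity. A new closure object at each SMC step is therefore a potential recompilation trigger, even when only the scalar λ has changed. It also means logdensity_fn leaks into every kernel call signature, even for non-SMC use cases where it never changes.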
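A minimal sketch of the effect, assuming the callable reaches `jax.jit` as a static argument. The function names and the trace counter are illustrative and not taken from the library:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Python side effects inside a jitted function run only while JAX traces it,
# so these counters record how many times each function was (re)traced.
trace_count = {"closure": 0, "params": 0}


@partial(jax.jit, static_argnums=1)
def step_with_closure(position, logdensity_fn):
    trace_count["closure"] += 1
    return logdensity_fn(position)


@jax.jit
def step_with_params(position, lmbda):
    trace_count["params"] += 1
    return -0.5 * lmbda * jnp.sum(position**2)


def make_logdensity(lmbda):
    # A fresh closure per temperature, as in the current SMC tempering setup.
    return lambda position: -0.5 * lmbda * jnp.sum(position**2)


position = jnp.ones(3)
schedule = (0.1, 0.5, 1.0)
closures = [make_logdensity(lmbda) for lmbda in schedule]  # three distinct objects

for lmbda, closure in zip(schedule, closures):
    step_with_closure(position, closure)            # new static argument each time
    step_with_params(position, jnp.float32(lmbda))  # same shape and dtype each time

print(trace_count)
```

Under these assumptions the closure variant is traced once per temperature, while the variant that receives λ as an array argument is traced once in total. Whether the real SMC code path hits this depends on how the kernel is jitted there.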

Proposed solution

Apply the same pattern recently used for integration_steps_fn / integration_steps_params in #880:

  1. Move logdensity_fn from the inner kernel call to build_kernel, with an extended signature logdensity_fn(position, *logdensity_params) -> float.
  2. Add logdensity_params: tuple = () to the inner kernel call.
  3. SMC tempering passes logdensity_params=(lmbda,) as a traced JAX scalar instead of constructing a new Python closure each step.

# build time: logdensity_fn is closed over once
kernel = hmc.build_kernel(
    logdensity_fn=lambda pos, lmbda: logprior(pos) + lmbda * loglikelihood(pos)
)

# call time: lmbda is a traced JAX array, no new closure
kernel(rng_key, state, step_size=..., logdensity_params=(current_lmbda,))

For the common (non-SMC) case logdensity_params=() and logdensity_fn keeps its current (position) -> float signature — no change for most users at build_kernel time, and the inner kernel call simply drops the logdensity_fn argument.
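To make the proposed calling convention concrete, here is a self-contained sketch. It uses a toy random-walk Metropolis kernel rather than the library's HMC, and this `build_kernel`, its arguments, and the two toy densities are hypothetical stand-ins, not the actual API:

```python
import jax
import jax.numpy as jnp


def build_kernel(logdensity_fn):
    """Toy kernel builder; logdensity_fn(position, *logdensity_params) -> float."""

    def kernel(rng_key, position, step_size, logdensity_params=()):
        key_proposal, key_accept = jax.random.split(rng_key)
        noise = jax.random.normal(key_proposal, position.shape)
        proposal = position + step_size * noise
        log_ratio = logdensity_fn(proposal, *logdensity_params) - logdensity_fn(
            position, *logdensity_params
        )
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        return jnp.where(accept, proposal, position)

    return kernel


def logprior(pos):
    return -0.5 * jnp.sum(pos**2)


def loglikelihood(pos):
    return -0.5 * jnp.sum((pos - 1.0) ** 2)


# Build time: the tempered density is closed over once.
kernel = build_kernel(lambda pos, lmbda: logprior(pos) + lmbda * loglikelihood(pos))


def tempered_step(position, inputs):
    rng_key, lmbda = inputs
    # Call time: lmbda arrives as a traced scalar, not as a new closure.
    position = kernel(rng_key, position, step_size=0.5, logdensity_params=(lmbda,))
    return position, position


lmbdas = jnp.linspace(0.0, 1.0, 5)
keys = jax.random.split(jax.random.PRNGKey(0), lmbdas.shape[0])
final_position, trajectory = jax.lax.scan(tempered_step, jnp.zeros(2), (keys, lmbdas))

# Non-SMC case: a plain (position) -> float density with the default empty tuple.
plain_kernel = build_kernel(logprior)
plain_position = plain_kernel(jax.random.PRNGKey(1), jnp.zeros(2), step_size=0.5)
```

Because λ is ordinary data here, the whole temperature schedule can run inside a single `jax.lax.scan` without rebuilding the kernel at each step.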

Trade-offs

Pros

  • No Python closure construction per SMC step → no JAX recompilation risk from changing density
  • Cleaner kernel call sites — logdensity_fn no longer passed on every step
  • Consistent with the integration_steps_params pattern already in the codebase

Cons

  • Breaking API change across every MCMC kernel, every adapter, and all SMC plumbing — large blast radius
  • logdensity_fn baked in at build_kernel time makes the low-level kernel slightly less composable across different densities
  • Most kernels don't need logdensity_params; adds a parameter that is a no-op for the majority of use cases

Questions for discussion

  • Is the recompilation risk in practice significant enough to justify the blast radius?
  • Should logdensity_params be a tuple (positional, like integration_steps_params) or a dict (named, more readable for e.g. {"lmbda": 0.5})?
  • Could this be done incrementally — e.g. deprecate logdensity_fn as a kernel call arg first, keep it as a pass-through shim, then remove in a later release?
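On the tuple-versus-dict question above, both containers are JAX pytrees, so either can carry traced values through `jax.jit`. The difference is only in how the kernel unpacks them. A small sketch with a hypothetical density and hypothetical wrapper names:

```python
import jax
import jax.numpy as jnp


def logdensity_fn(position, lmbda):
    # Hypothetical tempered density: standard normal prior, shifted likelihood.
    return -0.5 * jnp.sum(position**2) - 0.5 * lmbda * jnp.sum((position - 1.0) ** 2)


@jax.jit
def eval_with_tuple(position, logdensity_params):
    # Positional unpacking: the order of the tuple must match the signature.
    return logdensity_fn(position, *logdensity_params)


@jax.jit
def eval_with_dict(position, logdensity_params):
    # Keyword unpacking: the keys must match the parameter names.
    return logdensity_fn(position, **logdensity_params)


position = jnp.zeros(2)
from_tuple = eval_with_tuple(position, (0.5,))
from_dict = eval_with_dict(position, {"lmbda": 0.5})
```

Both calls evaluate the same density. The dict form is self-describing at the call site but ties callers to the parameter names, while the tuple form matches the existing integration_steps_params convention.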
