Background
Currently every MCMC inner kernel accepts logdensity_fn at call time:
```python
kernel(rng_key, state, logdensity_fn, step_size, ...)
```
The original motivation (see #495) was SMC tempering: at each SMC step the target density changes (λ · loglikelihood + logprior), so a new Python closure is constructed and passed to the kernel on every step. Keeping logdensity_fn as a call-time argument was the only way to support this without rebuilding the kernel.
The problem
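Python callables are static from JAX's perspective: they are not valid traced arguments, so a jitted function has to treat them as static arguments or close over them. A new closure object at each SMC step is therefore a potential recompilation trigger. It also means logdensity_fn leaks into every kernel call signature, even for non-SMC use cases where it never changes.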
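The recompilation mechanism can be shown with plain jax.jit. This is a toy sketch, not BlackJAX code: it assumes the density callable reaches the jitted function as a static argument, and the names step_static, step_traced, logprior and loglikelihood are hypothetical. A fresh closure per step misses jit's cache every time, while a traced scalar reuses a single compiled function.

```python
from functools import partial

import jax
import jax.numpy as jnp

# The Python body of a jitted function only runs while JAX traces it,
# so these counters record how many times each function was (re)traced.
trace_counts = {"static": 0, "traced": 0}


def logprior(x):
    return -0.5 * jnp.sum(x**2)


def loglikelihood(x):
    return -0.5 * jnp.sum((x - 3.0) ** 2)


@partial(jax.jit, static_argnums=0)
def step_static(logdensity_fn, x):
    # Hypothetical kernel step that receives the density as a static argument.
    trace_counts["static"] += 1
    return logdensity_fn(x)


@jax.jit
def step_traced(lmbda, x):
    # Same computation, but the temperature arrives as a traced array.
    trace_counts["traced"] += 1
    return logprior(x) + lmbda * loglikelihood(x)


x = jnp.ones(2)
closures = []  # keep every closure alive so each one is a distinct cache key
for lmbda in (0.1, 0.5, 1.0):
    # Current pattern: a new closure object per SMC step.
    def tempered(pos, l=lmbda):
        return logprior(pos) + l * loglikelihood(pos)

    closures.append(tempered)
    step_static(tempered, x)
    # Proposed pattern: one function, lmbda passed as data.
    step_traced(jnp.float32(lmbda), x)

print(trace_counts)  # {'static': 3, 'traced': 1}
```

Inside a fully jitted SMC loop the closure may instead capture a tracer and avoid retracing, which is why the cost above is a risk rather than a certainty.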
Proposed solution
Apply the same pattern recently used for integration_steps_fn → integration_steps_params in #880:
- Move logdensity_fn from the inner kernel call to build_kernel, with an extended signature logdensity_fn(position, *logdensity_params) -> float.
- Add logdensity_params: tuple = () to the inner kernel call.
- SMC tempering passes logdensity_params=(lmbda,) as a traced JAX scalar instead of constructing a new Python closure each step.
```python
# Build time: logdensity_fn is closed over once
kernel = hmc.build_kernel(
    logdensity_fn=lambda pos, lmbda: logprior(pos) + lmbda * loglikelihood(pos)
)
# Call time: lmbda is a traced JAX array, so no new closure is built
kernel(rng_key, state, step_size=..., logdensity_params=(current_lmbda,))
```
For the common (non-SMC) case, logdensity_params=() and logdensity_fn keeps its current (position) -> float signature: no change for most users at build_kernel time, and the inner kernel call simply drops the logdensity_fn argument.
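To make the proposed signatures concrete, here is a minimal sketch of the pattern on a toy random-walk Metropolis kernel. It is not the BlackJAX implementation: RWState, init and build_kernel are hypothetical names chosen for illustration. init is re-run at each temperature because the cached logdensity in the state depends on λ. The same build_kernel serves both the tempered case (logdensity_params=(lmbda,)) and the plain case (logdensity_params=()).

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class RWState(NamedTuple):
    position: jax.Array
    logdensity: jax.Array


def init(position, logdensity_fn: Callable, logdensity_params=()):
    # Hypothetical init: evaluate the (possibly parameterized) density once.
    return RWState(position, logdensity_fn(position, *logdensity_params))


def build_kernel(logdensity_fn: Callable):
    """Hypothetical builder: logdensity_fn(position, *logdensity_params) -> float
    is closed over once, at build time."""

    def kernel(rng_key, state, step_size, logdensity_params=()):
        key_move, key_accept = jax.random.split(rng_key)
        proposal = state.position + step_size * jax.random.normal(
            key_move, state.position.shape
        )
        # logdensity_params is a pytree of arrays: new values with the same
        # shapes and dtypes do not retrace the jitted kernel.
        proposal_logdensity = logdensity_fn(proposal, *logdensity_params)
        log_ratio = proposal_logdensity - state.logdensity
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        proposed = RWState(proposal, proposal_logdensity)
        new_state = jax.tree_util.tree_map(
            lambda new, old: jnp.where(accept, new, old), proposed, state
        )
        return new_state, accept

    return kernel


def logprior(x):
    return -0.5 * jnp.sum(x**2)


def loglikelihood(x):
    return -0.5 * jnp.sum((x - 3.0) ** 2)


def tempered_logdensity(x, lmbda):
    return logprior(x) + lmbda * loglikelihood(x)


# SMC-style use: one jitted kernel, temperature passed as data.
smc_kernel = jax.jit(build_kernel(tempered_logdensity))
key = jax.random.PRNGKey(0)
position = jnp.zeros(2)
for lmbda in (0.0, 0.5, 1.0):
    key, subkey = jax.random.split(key)
    params = (jnp.float32(lmbda),)
    smc_state = init(position, tempered_logdensity, params)
    smc_state, _ = smc_kernel(subkey, smc_state, 0.5, logdensity_params=params)
    position = smc_state.position

# Non-SMC use: plain (position) -> float density, logdensity_params left as ().
mcmc_kernel = jax.jit(build_kernel(logprior))
mcmc_state, _ = mcmc_kernel(key, init(jnp.zeros(2), logprior), 0.5)
```

Because *logdensity_params unpacks to nothing when it is (), a one-argument density needs no wrapper in the non-SMC case.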
Trade-offs
Pros
- No Python closure construction per SMC step → no JAX recompilation risk from a changing density
- Cleaner kernel call sites: logdensity_fn is no longer passed on every step
- Consistent with the integration_steps_params pattern already in the codebase
Cons
- Breaking API change across every MCMC kernel, every adapter, and all SMC plumbing: large blast radius
- logdensity_fn baked in at build_kernel time makes the low-level kernel slightly less composable across different densities
- Most kernels don't need logdensity_params; it adds a parameter that is a no-op for the majority of use cases
Questions for discussion
- Is the recompilation risk significant enough in practice to justify the blast radius?
- Should logdensity_params be a tuple (positional, like integration_steps_params) or a dict (named, more readable, e.g. {"lmbda": 0.5})?
- Could this be done incrementally, e.g. deprecate logdensity_fn as a kernel call arg first, keep it as a pass-through shim, then remove it in a later release?
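One possible shape for the pass-through shim in the last question, as a hedged sketch only. Assumptions: the transitional build_kernel accepts an optional logdensity_fn, legacy callers pass logdensity_fn by keyword (supporting the old positional slot would need extra argument handling), and the kernel body is a stand-in gradient-ascent step rather than a real MCMC transition. All names are hypothetical.

```python
import warnings
from typing import Callable, Optional

import jax
import jax.numpy as jnp


def build_kernel(logdensity_fn: Optional[Callable] = None):
    """Hypothetical transitional builder: logdensity_fn may still be omitted here."""
    built_logdensity_fn = logdensity_fn

    def kernel(rng_key, state, step_size, logdensity_params=(), *, logdensity_fn=None):
        if logdensity_fn is not None:
            # Legacy path: a call-time density still works, but warns.
            warnings.warn(
                "passing logdensity_fn to the kernel call is deprecated; "
                "pass it to build_kernel instead",
                DeprecationWarning,
                stacklevel=2,
            )
            fn = logdensity_fn
        elif built_logdensity_fn is not None:
            fn = built_logdensity_fn
        else:
            raise TypeError("logdensity_fn must be given to build_kernel or the kernel call")
        # Stand-in body: one gradient-ascent step on the density (rng_key is unused
        # here); a real kernel would make a stochastic proposal and accept/reject it.
        return state + step_size * jax.grad(fn)(state, *logdensity_params)

    return kernel


def logprior(x):
    return -0.5 * jnp.sum(x**2)


key = jax.random.PRNGKey(0)
x = jnp.ones(2)
new_style = build_kernel(logprior)(key, x, 0.1)                  # no warning
old_style = build_kernel()(key, x, 0.1, logdensity_fn=logprior)  # DeprecationWarning
```

If the shimmed kernel is wrapped in jax.jit, the warning fires only when the function is traced, not on every call.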