Algorithm category
MCMC (gradient-free)
Paper reference
Title: Slice Sampling Reparameterization Gradients
Authors: David M. Zoltowski, Diana Cai, Ryan P. Adams
Year: 2021
Venue: NeurIPS
Link: https://proceedings.neurips.cc/paper/2021/hash/c5c3d4fe6b2cc463c7d7ecba17cc9de7-Abstract.html
Existing implementations
Benefit and motivation
This proposal would add a slice-sampling algorithm that supports a use case BlackJAX does not currently cover: pathwise differentiation of samples with respect to the parameters of a parameterized log-density. In ordinary forward use it behaves like any other MCMC kernel; its main value is that the transitions are differentiable with respect to those parameters.
The main motivation is optimizing computations that involve parameterized unnormalized densities. In that setting, this sampler provides pathwise sample gradients rather than relying on score-function estimators. It works similarly to the reparameterization trick, except that it applies to general parameterized families of distributions.
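For readers less familiar with the distinction, here is a minimal sketch contrasting the two estimators on a case where the ordinary reparameterization trick already applies (a unit-variance Gaussian with mean theta). It is not part of the proposed sampler. The function names, the test function f, and the sample size are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Test function whose expectation we differentiate: E[x^2] = theta^2 + 1.
    return x ** 2


def pathwise_grad(theta, key, n=10_000):
    # Pathwise estimator: write x = theta + eps with eps ~ N(0, 1),
    # then differentiate straight through the samples.
    eps = jax.random.normal(key, (n,))
    return jax.grad(lambda t: jnp.mean(f(t + eps)))(theta)


def score_function_grad(theta, key, n=10_000):
    # Score-function estimator: E[f(x) * d/dtheta log p(x; theta)],
    # where the score of N(theta, 1) with respect to theta is (x - theta).
    x = theta + jax.random.normal(key, (n,))
    return jnp.mean(f(x) * (x - theta))


key = jax.random.PRNGKey(0)
theta = 1.5
g_path = pathwise_grad(theta, key)
g_score = score_function_grad(theta, key)
# Both estimate d/dtheta E[x^2] = 2 * theta; the pathwise estimate
# typically has much lower variance for the same number of samples.
```

The proposed sampler targets the case where no such closed-form reparameterization x = g(theta, eps) is available and only an unnormalized log-density can be evaluated.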
It might also be possible to extend some existing implementations to support differentiable samples, but this algorithm seems like a good place to start since it has existing JAX code by the original authors I could adapt. Slice sampling is rejection-free, which makes this much easier to implement than algorithms that include accept-reject steps.
Comparison to existing BlackJAX algorithms
The closest existing BlackJAX sampler is elliptical_slice, since both are slice-sampling methods. However, this algorithm adds functionality that BlackJAX does not currently provide.
Estimated JAX implementation effort
L (an existing research implementation is available and easy to adapt).
JAX-specific implementation notes
- The public API matches BlackJAX conventions, and the sampler can be called as usual. Optionally, additional parameters can be passed to the log-density as explicit positional arguments following the position argument. In my implementation, these must then also be passed to init and step. Another option might be to include them in the state. Happy to discuss if people have other ideas for the interface.
import jax
import jax.numpy as jnp
import blackjax

# Assume signature logdensity_fn(position, param1, param2).
# rng_key, num_steps, position, fn, param1 and param2 are assumed to be defined.
algo = blackjax.reparameterized_slice(logdensity_fn)
keys = jax.random.split(rng_key, num_steps)

def loss_fn(param1, param2):
    """Calculates E[fn(X)] where X ~ logdensity(param1, param2)."""
    state = algo.init(position, param1, param2)

    def one_step(state, key):
        state, _ = algo.step(key, state, param1, param2)
        return state, state.position

    _, positions = jax.lax.scan(one_step, state, keys)
    return jnp.mean(jax.vmap(fn)(positions))

grad_param1, grad_param2 = jax.grad(loss_fn, argnums=(0, 1))(param1, param2)
- My implementation uses jax.custom_vjp on a single sampler step to conform to the interface of other BlackJAX samplers, unlike the authors' implementation, which implements the backward pass over an entire sample trajectory.
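To make the jax.custom_vjp pattern concrete, here is a minimal sketch on a toy sub-problem, not the proposed implementation. It finds the right endpoint of a 1D slice by bisection, which is not usefully differentiable on its own, and attaches a backward pass via the implicit function theorem. The names logdensity and right_endpoint, the Gaussian target, and the bracketing interval are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def logdensity(x, theta):
    # Toy unnormalized log-density: zero-mean Gaussian with scale theta.
    return -0.5 * (x / theta) ** 2


@jax.custom_vjp
def right_endpoint(theta, log_height):
    # Bisection for the x > 0 satisfying logdensity(x, theta) == log_height.
    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        inside = logdensity(mid, theta) > log_height  # mid still within the slice
        return jnp.where(inside, mid, lo), jnp.where(inside, hi, mid)

    lo0 = jnp.zeros((), dtype=jnp.float32)
    hi0 = jnp.full((), 100.0, dtype=jnp.float32)  # assumed to bracket the endpoint
    lo, hi = jax.lax.fori_loop(0, 60, body, (lo0, hi0))
    return 0.5 * (lo + hi)


def _fwd(theta, log_height):
    x = right_endpoint(theta, log_height)
    return x, (x, theta)


def _bwd(residuals, x_bar):
    x, theta = residuals
    dl_dx, dl_dtheta = jax.grad(logdensity, argnums=(0, 1))(x, theta)
    # Implicit function theorem on logdensity(x, theta) - log_height = 0:
    # dx/dtheta = -dl_dtheta / dl_dx and dx/dlog_height = 1 / dl_dx.
    return (-dl_dtheta / dl_dx * x_bar, x_bar / dl_dx)


right_endpoint.defvjp(_fwd, _bwd)

theta = 2.0
log_height = jnp.log(0.5)
x = right_endpoint(theta, log_height)
dx_dtheta = jax.grad(right_endpoint)(theta, log_height)
# Analytically, x = theta * sqrt(-2 * log_height) and dx/dtheta = x / theta.
```

Because the custom rule is attached to one function call rather than to a whole trajectory, a step built this way composes with jax.lax.scan and jax.grad in the usual way.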
Are you willing to open a PR?
Yes — I have implemented this.
Intentional scope limits for the initial PR
- No attempt to generalize pathwise differentiation to other existing BlackJAX samplers in the same PR.