
Implement Slice Sampler Reparameterization Gradients #845

Description

@ischeinfeld

Algorithm category

MCMC (gradient-free)

Paper reference

Title: Slice Sampling Reparameterization Gradients
Authors: David M. Zoltowski, Diana Cai, Ryan P. Adams
Year: 2021
Venue: NeurIPS
Link: https://proceedings.neurips.cc/paper/2021/hash/c5c3d4fe6b2cc463c7d7ecba17cc9de7-Abstract.html

Existing implementations

Benefit and motivation

This would add a slice sampler that supports a use case BlackJAX does not currently cover: pathwise differentiation of samples with respect to the parameters of a log-density. In ordinary forward use it behaves like a normal MCMC kernel; its main value is that its transitions can be differentiated with respect to those parameters.

The main motivation is optimizing objectives that involve parameterized unnormalized densities. In that setting, this sampler provides pathwise sample gradients instead of relying on score-function estimators. It works similarly to the reparameterization trick, except that it applies to general parameterized distribution families rather than only those with a known differentiable sampling transform.
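A minimal sketch of the difference between the two estimators, in a toy case where the ordinary reparameterization trick does apply: X ~ N(mu, 1) written as X = mu + eps, with objective E[X^2] = mu^2 + 1, whose gradient is 2 mu. The function names are hypothetical and chosen for illustration; `jax.scipy.stats.norm.logpdf` is used only to compute the score d/dmu log p(X; mu).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def pathwise_grad(mu, eps):
    """Reparameterization (pathwise) estimate of d/dmu E[X^2], X = mu + eps."""
    def objective(m):
        x = m + eps  # samples are a differentiable function of the parameter
        return jnp.mean(x ** 2)
    return jax.grad(objective)(mu)


def score_function_grad(mu, eps):
    """Score-function estimate: mean of f(X) * d/dmu log p(X; mu)."""
    x = jax.lax.stop_gradient(mu + eps)  # samples are treated as constants
    # Per-sample score d/dmu log N(x; mu, 1), computed with autodiff.
    score = jax.vmap(
        jax.grad(lambda m, xi: norm.logpdf(xi, m, 1.0)), in_axes=(None, 0)
    )(mu, x)
    return jnp.mean(x ** 2 * score)
```

Both estimators are unbiased here, but the per-sample variance of the score-function term is much larger (about 51.6 versus 4 at mu = 1.5, from the Gaussian moments), which is the usual reason to prefer pathwise gradients when they are available.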

It might also be possible to extend some existing BlackJAX samplers to support differentiable samples, but this algorithm seems like a good place to start because the original authors released JAX code that I could adapt. Slice sampling is rejection-free, which makes this much easier to implement than for algorithms that include accept-reject steps.
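For readers less familiar with slice sampling, a minimal sketch of a standard univariate transition (stepping out followed by shrinkage) shows why there is no accept-reject step: candidate points that fall outside the slice are discarded inside the shrinkage loop, but the point that is finally returned is always kept. The helper names and the default `width=1.0` are assumptions for illustration, not BlackJAX's or the authors' code.

```python
import jax
import jax.numpy as jnp


def univariate_slice_step(key, x, logdensity_fn, width=1.0):
    """One slice-sampling transition (hypothetical helper for illustration)."""
    k_u, k_pos, k_shrink = jax.random.split(key, 3)
    # Slice height on the log scale: log u = log f(x) + log U, U ~ Uniform(0, 1),
    # written as log f(x) - Exp(1) to avoid log(0).
    log_u = logdensity_fn(x) - jax.random.exponential(k_u)

    # Randomly position a bracket of size `width` around x.
    left = x - width * jax.random.uniform(k_pos)
    right = left + width

    # Stepping out: widen each end until it lies outside the slice (no cap here).
    left = jax.lax.while_loop(lambda l: logdensity_fn(l) > log_u, lambda l: l - width, left)
    right = jax.lax.while_loop(lambda r: logdensity_fn(r) > log_u, lambda r: r + width, right)

    # Shrinkage: draw uniformly from the bracket until the draw is in the slice,
    # shrinking the bracket toward x after each miss.
    def not_done(carry):
        return jnp.logical_not(carry[4])

    def shrink(carry):
        key, left, right, _, _ = carry
        key, sub = jax.random.split(key)
        x_new = jax.random.uniform(sub, minval=left, maxval=right)
        inside = logdensity_fn(x_new) > log_u
        left = jnp.where(~inside & (x_new < x), x_new, left)
        right = jnp.where(~inside & (x_new >= x), x_new, right)
        return key, left, right, x_new, inside

    # The fourth slot is a placeholder overwritten on the first iteration.
    init = (k_shrink, left, right, left, jnp.array(False))
    _, _, _, x_new, _ = jax.lax.while_loop(not_done, shrink, init)
    return x_new


def sample_chain(key, x0, logdensity_fn, num_steps):
    """Run the transition repeatedly with lax.scan and return all positions."""
    def one_step(x, k):
        x = univariate_slice_step(k, x, logdensity_fn)
        return x, x

    _, xs = jax.lax.scan(one_step, x0, jax.random.split(key, num_steps))
    return xs
```

Run on a standard normal, for example `sample_chain(jax.random.PRNGKey(0), jnp.zeros(()), lambda x: -0.5 * x ** 2, 5000)`, the chain should have mean near 0 and variance near 1. Stepping out is unbounded in this sketch; practical implementations usually cap the number of expansions.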

Comparison to existing BlackJAX algorithms

The closest existing BlackJAX sampler is elliptical_slice, since both are slice-sampling methods. However, this algorithm provides functionality that BlackJAX does not currently have: samples that can be differentiated with respect to the parameters of the log-density.

Estimated JAX implementation effort

L: the existing research implementation is easy to adapt.

JAX-specific implementation notes

  • The public API follows BlackJAX conventions, and the sampler can be called as usual. Optionally, additional parameters of the log-density can be passed as explicit positional arguments after the position argument; in my implementation these must then also be passed to init and step. Another option might be to store them in the state. Happy to discuss other ideas for the interface.
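```python
import jax
import jax.numpy as jnp
import blackjax

# Assumes the signature logdensity_fn(position, param1, param2).
# `blackjax.reparameterized_slice` is the sampler proposed in this issue.
# `fn`, `position`, `rng_key`, `num_steps`, `param1` and `param2` are assumed defined.
algo = blackjax.reparameterized_slice(logdensity_fn)
keys = jax.random.split(rng_key, num_steps)

def loss_fn(param1, param2):
    """Estimates E[fn(X)] where X has log-density logdensity_fn(., param1, param2)."""
    state = algo.init(position, param1, param2)

    def one_step(state, key):
        # The log-density parameters are passed explicitly to every step.
        state, _ = algo.step(key, state, param1, param2)
        return state, state.position

    _, positions = jax.lax.scan(one_step, state, keys)
    return jnp.mean(jax.vmap(fn)(positions))

grad_param1, grad_param2 = jax.grad(loss_fn, argnums=(0, 1))(param1, param2)
```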
  • My implementation applies jax.custom_vjp to a single sampler step so that it conforms to the interface of other BlackJAX samplers. This differs from the authors' implementation, which implements the backward pass over an entire sample trajectory.

Are you willing to open a PR?

Yes — I have implemented this.

Intentional scope limits for the initial PR

  • No attempt, in the same PR, to extend pathwise differentiation to other existing BlackJAX samplers.
