Add discrete Langevin sampler (DLP / DMALA) for discrete distributions #835

Description

@junpenglao

Algorithm category: MCMC (discrete / gradient-based)

Paper reference

"A Langevin-like Sampler for Discrete Distributions" — Ruqi Zhang, Xingchao Liu, Qiang Liu. Proceedings of the 39th International Conference on Machine Learning (ICML), PMLR 162, 2022. https://arxiv.org/abs/2206.09914

Existing implementations

Benefit and motivation

BlackJAX currently has no sampler for discrete distributions. The Discrete Langevin Proposal (DLP) carries the gradient-based Langevin approach over to discrete spaces: it proposes updates to all coordinates in parallel in a single step, with a step size controlling how far the proposal moves, in contrast to the coordinate-by-coordinate updates of Gibbs sampling. The paper reports strong performance on Ising models, restricted Boltzmann machines, deep energy-based models, binary neural networks, and language generation tasks. Adding DLP/DMALA would let BlackJAX target a large class of discrete problems that its existing kernels do not cover.
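To make the parallel update concrete, here is a minimal sketch of the DLP proposal restricted to binary variables in {0, 1}. The function names (`dlp_flip_logits`, `dlp_propose`), the toy chain log-probability, and the float encoding of the state are assumptions made for illustration; none of this is existing BlackJAX API. It assumes the paper's factorized proposal, in which coordinate i moves to x'_i with probability proportional to exp(0.5 * grad_i * (x'_i - x_i) - (x'_i - x_i)^2 / (2 * step_size)).

```python
import jax
import jax.numpy as jnp


def dlp_flip_logits(logprob_fn, x, step_size):
    """Logit of flipping each binary coordinate (hypothetical helper).

    For x_i in {0, 1}, flipping changes the coordinate by (1 - 2 * x_i),
    whose square is always 1, so the quadratic penalty is 1 / (2 * step_size).
    Staying put has logit 0.
    """
    grad = jax.grad(logprob_fn)(x)
    delta = 1.0 - 2.0 * x
    return 0.5 * grad * delta - 1.0 / (2.0 * step_size)


def dlp_propose(rng_key, logprob_fn, x, step_size):
    """Propose a new state by flipping all coordinates independently."""
    logits = dlp_flip_logits(logprob_fn, x, step_size)
    flip = jax.random.bernoulli(rng_key, jax.nn.sigmoid(logits))
    return jnp.where(flip, 1.0 - x, x)


# Toy target (an assumption): a 1D chain with nearest-neighbour coupling,
# written on floats so that jax.grad can differentiate it.
def logprob_fn(x):
    spins = 2.0 * x - 1.0
    return 0.8 * jnp.sum(spins[:-1] * spins[1:])


x0 = jnp.zeros(4)
x1 = dlp_propose(jax.random.PRNGKey(0), logprob_fn, x0, step_size=0.5)
```

A single gradient evaluation yields flip probabilities for every coordinate at once, which is what separates this from a Gibbs sweep. Smaller step sizes push the flip logits down, so fewer coordinates change per proposal.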

Comparison to existing BlackJAX algorithms

  • Closest existing: blackjax.rmh (random-walk Metropolis-Hastings, gradient-free). BlackJAX has no gradient-based kernel for discrete variables.
  • Advantage: Gradient-guided proposals for discrete variables; parallel coordinate updates; includes unadjusted, Metropolis-adjusted, stochastic, and preconditioned variants; the paper proves zero asymptotic bias for log-quadratic distributions.
  • Limitation: Requires a log-probability that is differentiable with respect to the discrete variables (typically via relaxation or straight-through estimators); not applicable to purely combinatorial state spaces without further modification.

Estimated JAX implementation effort

M — The core DLP proposal is a closed-form operation on discrete variables that is straightforward in JAX. The DMALA (Metropolis-adjusted) variant adds a standard accept/reject step. The DMALAX reference implementation already follows BlackJAX API conventions and could be adapted with modest effort.
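As a rough illustration of the effort involved, the sketch below adds a Metropolis-Hastings correction to a binary DLP proposal. All names (`flip_logits`, `proposal_logprob`, `dmala_step`, `run_chain`) and the toy target are hypothetical and do not follow any confirmed BlackJAX or reference-implementation signature. The key assumption is that the state is binary, so the reverse move flips exactly the coordinates that the forward move flipped.

```python
import jax
import jax.numpy as jnp


def flip_logits(logprob_fn, x, step_size):
    # Per-coordinate logit of flipping x_i in {0, 1}; staying has logit 0.
    grad = jax.grad(logprob_fn)(x)
    return 0.5 * grad * (1.0 - 2.0 * x) - 1.0 / (2.0 * step_size)


def proposal_logprob(logits, flipped):
    # Factorized proposal: log sigmoid(l) where flipped, log sigmoid(-l) where kept.
    return jnp.sum(
        jnp.where(flipped, jax.nn.log_sigmoid(logits), jax.nn.log_sigmoid(-logits))
    )


def dmala_step(rng_key, logprob_fn, x, step_size):
    key_prop, key_acc = jax.random.split(rng_key)
    fwd_logits = flip_logits(logprob_fn, x, step_size)
    flipped = jax.random.bernoulli(key_prop, jax.nn.sigmoid(fwd_logits))
    x_new = jnp.where(flipped, 1.0 - x, x)
    # Going back from x_new to x flips the same set of coordinates.
    rev_logits = flip_logits(logprob_fn, x_new, step_size)
    log_ratio = (
        logprob_fn(x_new)
        - logprob_fn(x)
        + proposal_logprob(rev_logits, flipped)
        - proposal_logprob(fwd_logits, flipped)
    )
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, x_new, x), accept


def run_chain(rng_key, logprob_fn, x0, step_size, num_steps):
    def body(x, key):
        x_next, accepted = dmala_step(key, logprob_fn, x, step_size)
        return x_next, (x_next, accepted)

    keys = jax.random.split(rng_key, num_steps)
    _, (samples, accepted) = jax.lax.scan(body, x0, keys)
    return samples, accepted


# Toy log-quadratic target (an assumption for the demo).
def logprob_fn(x):
    spins = 2.0 * x - 1.0
    return 0.5 * jnp.sum(spins[:-1] * spins[1:]) + 0.1 * jnp.sum(spins)


samples, accepted = run_chain(
    jax.random.PRNGKey(0), logprob_fn, jnp.zeros(3), step_size=0.5, num_steps=50
)
```

Each step costs two gradient evaluations, one at the current state and one at the proposed state. A real kernel would likely cache the gradient in its state object, as gradient-based kernels commonly do.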

JAX-specific implementation notes

Gradient computation through discrete variables can use JAX's standard jax.grad when the log-prob is defined over relaxed or continuous embeddings. The proposal and accept/reject steps are fully jax.jit-compatible. No lax.while_loop or custom VJP is needed for the basic variant; preconditioned variants may benefit from custom_vjp to handle matrix square-root computations efficiently.
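A small sketch of the gradient point: if the log-probability is written as an ordinary function of a float array, `jax.grad` can be evaluated at a binary-valued input and wrapped in `jax.jit` without special handling. The matrix `W`, vector `b`, and the name `logprob_fn` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Log-quadratic toy target on {0, 1}^2, defined for any float input.
W = jnp.array([[0.0, 0.5], [0.5, 0.0]])
b = jnp.array([0.1, -0.2])


def logprob_fn(x):
    return x @ W @ x + b @ x


grad_fn = jax.jit(jax.grad(logprob_fn))

# A binary state encoded as floats so that it is a valid input to jax.grad.
x = jnp.array([1.0, 0.0])
g = grad_fn(x)

# For this quadratic form the gradient is available in closed form.
g_analytic = (W + W.T) @ x + b
```

Storing the state as an integer array would make `jax.grad` raise an error, so a kernel built this way would need to cast to a float dtype before differentiating.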

Willing to open a PR?

No — filing for community interest (re-filed from #263)


Re-filed from #263, which was closed as stale, using the new structured proposal format.
