Algorithm category: MCMC (discrete / gradient-based)
Paper reference
"A Langevin-like Sampler for Discrete Distributions" — Ruqi Zhang, Xingchao Liu, Qiang Liu. Proceedings of the 39th International Conference on Machine Learning (ICML), PMLR 162, 2022. https://arxiv.org/abs/2206.09914
Existing implementations
Benefit and motivation
BlackJAX currently has no sampler for discrete distributions. The Discrete Langevin Proposal (DLP) adapts the continuous Langevin gradient-based approach to discrete spaces, updating all coordinates in parallel in a single step controlled by a step size — contrasting with the coordinate-by-coordinate nature of Gibbs sampling. The paper demonstrates strong performance on Ising models, restricted Boltzmann machines, deep energy-based models, binary neural networks, and language generation tasks. Adding DLP/DMALA would open BlackJAX to a large class of problems currently unreachable with existing kernels.
Comparison to existing BlackJAX algorithms
- Closest existing: blackjax.rmh (random-walk Metropolis-Hastings, gradient-free). No discrete analogue exists in BlackJAX.
- Advantage: Gradient-guided proposals for discrete variables; parallel coordinate updates; includes unadjusted, Metropolis-adjusted, stochastic, and preconditioned variants; the paper proves that the stationary distribution of the unadjusted proposal has zero asymptotic bias for log-quadratic distributions.
- Limitation: Requires differentiable log-probability with respect to discrete variables (typically via relaxation or straight-through estimators); not applicable to purely combinatorial state spaces without further modification.
Estimated JAX implementation effort
M — The core DLP proposal is a closed-form operation on discrete variables that is straightforward in JAX. The DMALA (Metropolis-adjusted) variant adds a standard accept/reject step. The DMALAX reference implementation already follows BlackJAX API conventions and could be adapted with modest effort.
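For reference, the closed-form proposal mentioned above can be written as follows. This is my reading of the paper, with my own notation (f is the log-probability, α the step size, D the number of coordinates), so please check it against the paper before relying on it. The proposal factorizes over coordinates, which is what allows the parallel update, and the Metropolis-adjusted variant uses the usual acceptance probability.

```latex
q(x' \mid x) = \prod_{i=1}^{D} q_i(x'_i \mid x),
\qquad
q_i(x'_i \mid x) \propto
\exp\!\left( \tfrac{1}{2}\, \nabla f(x)_i \,(x'_i - x_i)
             - \frac{(x'_i - x_i)^2}{2\alpha} \right)

% Metropolis-Hastings acceptance probability for the adjusted variant
A(x' \mid x) = \min\!\left( 1,\;
  \frac{\exp f(x')\; q(x \mid x')}{\exp f(x)\; q(x' \mid x)} \right)
```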
JAX-specific implementation notes
Gradient computation through discrete variables can use JAX's standard jax.grad when the log-prob is defined over relaxed or continuous embeddings. The proposal and accept/reject steps are fully jax.jit-compatible. No lax.while_loop or custom VJP is needed for the basic variant; preconditioned variants may benefit from custom_vjp to handle matrix square-root computations efficiently.
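To make the notes above concrete, here is a minimal sketch of one Metropolis-adjusted step for binary variables in {0, 1}. It assumes the log-prob accepts float arrays so that jax.grad applies. The names flip_logits, log_q and dmala_step are hypothetical and not part of BlackJAX or of any reference implementation, and the toy target (independent Bernoulli coordinates) is only there to exercise the code.

```python
import jax
import jax.numpy as jnp


def flip_logits(logprob_fn, x, step_size):
    # Per-coordinate logit of flipping x_i -> 1 - x_i under the binary
    # Langevin-style proposal; staying in place has logit 0.
    grad = jax.grad(logprob_fn)(x)
    return 0.5 * grad * (1.0 - 2.0 * x) - 1.0 / (2.0 * step_size)


def log_q(logits, flipped):
    # Log-probability of a given flip pattern under independent Bernoulli flips.
    return jnp.sum(
        jnp.where(flipped, jax.nn.log_sigmoid(logits), jax.nn.log_sigmoid(-logits))
    )


def dmala_step(rng_key, x, logprob_fn, step_size):
    # Hypothetical helper: one proposal plus accept/reject step.
    key_prop, key_acc = jax.random.split(rng_key)
    fwd = flip_logits(logprob_fn, x, step_size)
    flipped = jax.random.bernoulli(key_prop, jax.nn.sigmoid(fwd))
    x_new = jnp.where(flipped, 1.0 - x, x)
    # The reverse move flips the same coordinates, starting from x_new.
    rev = flip_logits(logprob_fn, x_new, step_size)
    log_ratio = (
        logprob_fn(x_new) - logprob_fn(x) + log_q(rev, flipped) - log_q(fwd, flipped)
    )
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, x_new, x), accept


# Toy target (an assumption for illustration): independent Bernoulli coordinates.
theta = jnp.array([2.0, -2.0, 0.0, 1.0])


def logprob_fn(x):
    return jnp.sum(theta * x)


def one_step(x, key):
    x, accepted = dmala_step(key, x, logprob_fn, step_size=0.5)
    return x, (x, accepted)


keys = jax.random.split(jax.random.PRNGKey(0), 2000)
_, (samples, accepted) = jax.lax.scan(one_step, jnp.zeros(4), keys)
```

The whole chain runs inside jax.lax.scan, so the step traces without a while loop or custom VJP. A real kernel would presumably carry the log-prob and gradient in its state rather than recomputing them each step.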
Willing to open a PR?
No. Filing for community interest.
Re-filed from #263, which was closed as stale, using the new structured proposal format.