Skip to content

Nested Sampling#911

Open
yallup wants to merge 285 commits into
blackjax-devs:mainfrom
handley-lab:nested_sampling
Open

Nested Sampling#911
yallup wants to merge 285 commits into
blackjax-devs:mainfrom
handley-lab:nested_sampling

Conversation

@yallup

@yallup yallup commented May 12, 2026

Copy link
Copy Markdown

Description

Nested Sampling (NS) is a workhorse inference method in the physical sciences (cosmology, particle physics, gravitational waves, materials), where it is routinely used for both posterior characterisation and evidence-driven model comparison. Bringing a first-class NS implementation into blackjax lets existing analyses be benchmarked against blackjax's SMC alternatives under a shared, JIT-compatible API, and makes the method available to users who already build pipelines on top of blackjax.

This PR contributes three pieces, layered to match existing blackjax conventions:

  1. Nested Sampling kernel (blackjax.ns) -- a base NS kernel and an adaptive variant, with from_mcmc patterns mirroring blackjax.smc's inner-kernel composition. The interface is sampler-agnostic so any constrained inner kernel can drive the live-point update.

  2. Slice Sampling (blackjax.mcmc.ss) -- a standalone hit-and-run slice sampler (HRSS) implementing Neal's stepping-out and rejection with contraction. Usable on its own as an MCMC kernel, and serves as the canonical inner kernel for NS.

  3. Nested Slice Sampling (NSS, blackjax.ns.nss) -- the practical synthesis: a performant constrained sampler wiring HRSS into the NS from_mcmc pattern. Matches the algorithm described in the companion paper arXiv:2601.23252.

Sampling-book examples can follow this PR -- an NSS example extending the existing SMC material, and a standalone slice-sampling example.

This branch has been through extensive iteration; happy to squash changes to avoid git history bloat

Related issues / discussions

Closes earlier PR #755 , builds on issue #753

Checklist

General

  • The branch is rebased on the latest main
  • Commit messages are clear and descriptive
  • pre-commit run --all-files passes (black, isort, flake8, mypy)
  • Tests cover the changes (mamba run -n blackjax python -m pytest tests/)

Code quality

  • Public functions have docstrings following the NumPy style guide
  • Naming follows existing conventions (logdensity, jax.tree.map, jax.random.key(), jnp.clip(min=, max=))
  • All new code is JIT-compatible

New sampler / algorithm

  • There is an open issue discussing this algorithm (use the sampler proposal template)
  • Follows the three-layer pattern: init / build_kernel / as_top_level_api
  • Registered in blackjax/__init__.py
  • An example notebook has been added or updated

williamjameshandley and others added 29 commits November 24, 2025 12:27
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Reinstate internal adaptive and tuning
Flatten NSState and move to external parameter management
Extract evidence integration to NSIntegrator
Bundle particle data into StateWithLogLikelihood
Resolved merge conflicts by accepting upstream changes:
- Added persistent_sampling SMC feature
- Added Lfactor parameter to mclmc_adaptation
- Renamed lmbda to tempering_param for better clarity
- Preserved nested sampling functionality

All upstream SMC improvements integrated while maintaining nested sampling code.
Merge upstream blackjax-devs main into nested sampling branch
…ling-api

# Conflicts:
#	blackjax/__init__.py
#	blackjax/adaptation/mclmc_adaptation.py
#	blackjax/smc/tempered.py
#	tests/smc/test_inner_kernel_tuning.py
#	tests/smc/test_pretuning.py
- Fix loglikelihood_birth type hint: Array -> float
- Rename ambiguous l/r to left/right in ss.py (E741)
- Fix test function signatures and expected values
Refactor slice sampler to decouple constraint handling from core algorithm. Refactors large portions of the machinery and introduces some useful features:
- Cap on number of slice sampling steps to prevent unprotected loops
- Refactored clearer machinery for different update strategies
- Matches published description of algorithm
* Generalise inner kernel interface for sampler-agnostic NS

The base kernel was MCMC-centric: it selected start particles and
pre-batched RNG keys before calling the inner kernel. This prevented
non-MCMC inner kernels (e.g. rejection sampling from ellipsoids) which
need the full live set.

- Inner kernel now receives (rng_key, state, dead_idx, loglikelihood_0)
  with a single PRNG key and the full duck-typed NS state
- delete_fn simplified to (state, num_delete) -> (dead_idx, target_idx);
  start particle selection moved into the MCMC wrapper
- from_mcmc.update_with_mcmc_take_last handles survivor selection and
  key splitting internally

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Simplify inner kernel signature by removing dead_idx parameter

The inner kernel only used dead_idx.shape[0] for the count, which is
already known at construction time. Move num_delete to a closure variable
in update_with_mcmc_take_last, giving a cleaner constrained-prior-sampler
interface: (rng_key, state, loglikelihood_0) -> (new_particles, info).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Simplify constrained MCMC step by removing retry loop

* Add update_strategy parameter to nss.build_kernel

Allows custom update strategies (e.g., sharded) to be plugged in
while reusing the rest of the NSS machinery. The strategy is duck-typed
with signature: (constrained_step_fn, num_inner_steps, num_delete) -> inner_kernel

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Fix black formatting for delete_fn signature

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: yallup <david.yallup@gmail.com>
* surely rng_key was supposed to be passed through here

* dunno what was going on with this type hint
Cherry-picks the fix from origin/irmh (94cf667): float32 cumsum drift
in logX can push log1mexp's input slightly positive, yielding NaN in
log_weights / sample / ess. Clamping to -eps keeps it well-defined.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
NS: clamp log1mexp input to avoid NaN from float32 logX drift
@yallup yallup mentioned this pull request Jun 23, 2026
11 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants