Nested Sampling#911
Open
yallup wants to merge 285 commits into
Open
Conversation
…o nested_sampling
…o nested_sampling
…o nested_sampling
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Reinstate internal adaptive and tuning
Flatten NSState and move to external parameter management
Extract evidence integration to NSIntegrator
Bundle particle data into StateWithLogLikelihood
…blackjax into refactor-slice-sampling-api
This reverts commit d454247.
Resolved merge conflicts by accepting upstream changes: - Added persistent_sampling SMC feature - Added Lfactor parameter to mclmc_adaptation - Renamed lmbda to tempering_param for better clarity - Preserved nested sampling functionality All upstream SMC improvements integrated while maintaining nested sampling code.
Merge upstream blackjax-devs main into nested sampling branch
…ling-api # Conflicts: # blackjax/__init__.py # blackjax/adaptation/mclmc_adaptation.py # blackjax/smc/tempered.py # tests/smc/test_inner_kernel_tuning.py # tests/smc/test_pretuning.py
- Fix loglikelihood_birth type hint: Array -> float - Rename ambiguous l/r to left/right in ss.py (E741) - Fix test function signatures and expected values
…blackjax into refactor-slice-sampling-api
Refactor slice sampler to decouple constraint handling from core algorithm. Refactors large portions of the machinery and introduces some useful features: - Cap on number of slice sampling steps to prevent unprotected loops - Refactored clearer machinery for different update strategies - Matches published description of algorithm
* Generalise inner kernel interface for sampler-agnostic NS The base kernel was MCMC-centric: it selected start particles and pre-batched RNG keys before calling the inner kernel. This prevented non-MCMC inner kernels (e.g. rejection sampling from ellipsoids) which need the full live set. - Inner kernel now receives (rng_key, state, dead_idx, loglikelihood_0) with a single PRNG key and the full duck-typed NS state - delete_fn simplified to (state, num_delete) -> (dead_idx, target_idx); start particle selection moved into the MCMC wrapper - from_mcmc.update_with_mcmc_take_last handles survivor selection and key splitting internally Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Simplify inner kernel signature by removing dead_idx parameter The inner kernel only used dead_idx.shape[0] for the count, which is already known at construction time. Move num_delete to a closure variable in update_with_mcmc_take_last, giving a cleaner constrained-prior-sampler interface: (rng_key, state, loglikelihood_0) -> (new_particles, info). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Simplify constrained MCMC step by removing retry loop * Add update_strategy parameter to nss.build_kernel Allows custom update strategies (e.g., sharded) to be plugged in while reusing the rest of the NSS machinery. The strategy is duck-typed with signature: (constrained_step_fn, num_inner_steps, num_delete) -> inner_kernel Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Fix black formatting for delete_fn signature Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: yallup <david.yallup@gmail.com>
Double traced jit
Cherry-picks the fix from origin/irmh (94cf667): float32 cumsum drift in logX can push log1mexp's input slightly positive, yielding NaN in log_weights / sample / ess. Clamping to -eps keeps it well-defined. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
NS: clamp log1mexp input to avoid NaN from float32 logX drift
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Nested Sampling (NS) is a workhorse inference method in the physical sciences (cosmology, particle physics, gravitational waves, materials), where it is routinely used for both posterior characterisation and evidence-driven model comparison. Bringing a first-class NS implementation into blackjax lets existing analyses be benchmarked against blackjax's SMC alternatives under a shared, JIT-compatible API, and makes the method available to users who already build pipelines on top of blackjax.
This PR contributes three pieces, layered to match existing blackjax conventions:
Nested Sampling kernel (
blackjax.ns) -- a base NS kernel and an adaptive variant, withfrom_mcmcpatterns mirroringblackjax.smc's inner-kernel composition. The interface is sampler-agnostic so any constrained inner kernel can drive the live-point update.Slice Sampling (
blackjax.mcmc.ss) -- a standalone hit-and-run slice sampler (HRSS) implementing Neal's stepping-out and rejection with contraction. Usable on its own as an MCMC kernel, and serves as the canonical inner kernel for NS.Nested Slice Sampling (NSS,
blackjax.ns.nss) -- the practical synthesis: a performant constrained sampler wiring HRSS into the NSfrom_mcmcpattern. Matches the algorithm described in the companion paper arXiv:2601.23252.Sampling-book examples can follow this PR -- an NSS example extending the existing SMC material, and a standalone slice-sampling example.
This branch has been through extensive iteration; happy to squash changes to avoid git history bloat
Related issues / discussions
Closes earlier PR #755 , builds on issue #753
Checklist
General
mainpre-commit run --all-filespasses (black, isort, flake8, mypy)mamba run -n blackjax python -m pytest tests/)Code quality
logdensity,jax.tree.map,jax.random.key(),jnp.clip(min=, max=))New sampler / algorithm
init/build_kernel/as_top_level_apiblackjax/__init__.py