Skip to content

feat: AMX-BF16 inner-product distance computer for IndexHNSWSQ#5235

Open
xtangxtang wants to merge 5 commits into
facebookresearch:mainfrom
epeshared:bf16-amx-hnsw
Open

feat: AMX-BF16 inner-product distance computer for IndexHNSWSQ#5235
xtangxtang wants to merge 5 commits into
facebookresearch:mainfrom
epeshared:bf16-amx-hnsw

Conversation

@xtangxtang

@xtangxtang xtangxtang commented May 26, 2026

Copy link
Copy Markdown

Summary

Accelerate IndexHNSWSQ(QT_bf16) inner-product search on Intel AMX
(Sapphire Rapids / Granite Rapids and newer) by computing 16 bf16 inner
products per distances_batch_16 call in a single AMX tile pass.

Re-architected from the original standalone-storage version to fit faiss's
existing abstractions (per review):

1. BF16 via the ScalarQuantizer (IndexHNSWSQ)

QuantizerBF16 (QT_bf16) already stores each component as bf16, so the AMX
tile engine consumes the SQ codes directly — no separate bf16_storage, no
re-encoding, and no bf16-specific members on IndexHNSW.

2. Batched DistanceComputer, 4 → 16

DistanceComputer::distances_batch_16 (default = 4× distances_batch_4, so
every computer benefits for free) is overridden by the AMX bf16 computer. The
level-0 HNSW loop (search_from_candidates_fixVT) now batches 16 neighbours;
the upper-level greedy and unbounded loops keep batch-4.

3. AMX behind the SIMD dynamic dispatch

New SIMDLevel::AMX (above AVX512_SPR) with runtime CPUID + tile-permission
detection and fallback chain AMX → AVX512_SPR → AVX512 → AVX2 → NONE. The
kernel is faiss/impl/scalar_quantizer/sq-amx.cpp (guarded by
COMPILE_SIMD_AMX, compiled into a faiss_amx target with
-mamx-tile -mamx-bf16, mirroring faiss_avx512_spr).
ScalarQuantizer::get_distance_computer dispatches via with_simd_level_amx.

Scope: inner product only

AMX is routed only for QT_bf16 + METRIC_INNER_PRODUCT. L2 (and every other
quantizer type) uses the AVX512 SQ path, because the per-vector ‖c‖² needed
to turn the tile dot product into an L2 distance offsets the tile gain.

Performance

Granite Rapids (Xeon 6972P), d=768, 100k vectors, efSearch=128; AMX vs the
AVX512 bf16 SQ path on an identical index, recall unchanged (~0.99):

threads AVX512 QPS AMX QPS speedup
1 1,386 1,831 1.32×
16 (8C/16T) 15,008 18,217 1.21×

At full-socket thread counts the workload becomes memory-bandwidth-bound and
the two converge; the gain is in the compute-bound regime.

Usage

faiss::IndexHNSWSQ index(
    d, faiss::ScalarQuantizer::QT_bf16, M, faiss::METRIC_INNER_PRODUCT);
index.train(n, xb);
index.add(n, xb);
index.search(nq, xq, k, distances, labels);  // AMX tile path on SPR/GNR+

Commits

  1. feat: add AMX SIMD level to the dynamic dispatch framework
  2. feat: add DistanceComputer::distances_batch_16 for HNSW level-0 search
  3. feat: add AMX-BF16 tile inner-product distance computer

@meta-cla meta-cla Bot added the CLA Signed label May 26, 2026
@xtangxtang xtangxtang changed the title feat: Add BF16+AMX acceleration for HNSW search feat: Add BF16+AMX acceleration for HNSW search (batch-16 distance computation) May 26, 2026
@mdouze

mdouze commented Jun 3, 2026

Copy link
Copy Markdown
Contributor

Thanks for the work on this PR.

In order for us to accept it, it should be re-architctured in the following way:

Thanks!

xtang added 3 commits June 5, 2026 14:58
Add SIMDLevel::AMX (Intel AMX-BF16 tiles, Sapphire Rapids / Granite
Rapids and newer) above AVX512_SPR in faiss' SIMD dynamic-dispatch
framework:

- simd_levels.h: new enum value; SINGLE_SIMD_LEVEL, the 256/512-bit
  level selectors and simd_width() treat AMX like AVX512 for generic
  SIMD types (only dedicated AMX kernels differ).
- simd_levels.cpp: runtime detection via CPUID leaf 7 (EDX bit 22
  AMX-BF16, bit 24 AMX-TILE) gated on XCR0 tile state, plus a one-time
  arch_prctl(ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) permission
  request; to_string / to_simd_level.
- simd_dispatch.h: get_simd_fallback(AMX) -> AVX512_SPR, an
  AVAILABLE_SIMD_LEVELS_A0_AMX mask, the AMX case in the DD switch,
  and a with_simd_level_amx() helper.

Fallback chain: AMX -> AVX512_SPR -> AVX512 -> AVX2 -> NONE. No AMX
kernels are wired up yet; this is the framework plumbing only.
Extend the batched DistanceComputer interface from 4 to 16, as requested
in PR review, so an AMX-specialized scalar-quantizer distance computer
can fill a full BF16 tile:

- DistanceComputer: new virtual distances_batch_16(idx, dis) with a
  default that groups the work into 4x distances_batch_4 -- so every
  existing computer with a batch_4 specialization benefits for free.
  NegativeDistanceComputer overrides it to negate (inner product).
- HNSW search_from_candidates_fixVT (the level-0 hot loop) now
  accumulates 16 neighbors and calls distances_batch_16 instead of 4.
  Upper-level greedy and the unbounded search keep batch-4.
Accelerate IndexHNSWSQ(QT_bf16) inner-product search on Intel AMX
(Sapphire Rapids / Granite Rapids and newer) by computing 16 bf16 inner
products per distances_batch_16 call in a single AMX tile pass.

- impl/scalar_quantizer/sq-amx.cpp: the AMX tile kernel plus a DCBF16Amx
  distance computer that composes the AVX512 bf16 DC for the scalar
  operator()/query_to_code remainder path and overrides
  distances_batch_16 with the tile dot product. QuantizerBF16 already
  stores codes as bf16, so the tile engine consumes the SQ storage
  directly (no re-encoding). sq_select_distance_computer<AMX> routes only
  QT_bf16 + inner product to AMX; L2 and all other quantizer types
  delegate to the AVX512 scalar-quantizer path.
- ScalarQuantizer.cpp: get_distance_computer dispatches through
  with_simd_level_amx so AMX is in the fallback mask.
- CMakeLists.txt: faiss_amx target (mirrors faiss_avx512_spr) compiled
  with -mamx-tile -mamx-bf16 and COMPILE_SIMD_AMX, plus FAISS_SIMD_AMX_SRC.

~1.3-1.5x inner-product QPS over the AVX512 bf16 path on Granite Rapids
(Xeon 6972P), recall unchanged.
@xtangxtang

Copy link
Copy Markdown
Author

Thanks for the detailed review, @mdouze — I've re-architected the PR along all three points. The standalone BF16 storage + AMX path is gone; everything now flows through the existing ScalarQuantizer and SIMD dynamic-dispatch infrastructure.

1. BF16 encoding lives in the ScalarQuantizer / IndexHNSWSQ
The separate bf16_storage / bf16_norms members on IndexHNSW are removed. Storage and query encoding now use IndexHNSWSQ(d, ScalarQuantizer::QT_bf16, M). Since QuantizerBF16 already stores each component as bf16, the AMX tile engine consumes the SQ codes directly — no separate storage and no re-encoding.

2. The batched DistanceComputer interface, extended 4 → 16
Added DistanceComputer::distances_batch_16, with a default implementation that groups the work into 4× distances_batch_4 (so every existing computer benefits for free), overridden by the AMX bf16 computer. The level-0 HNSW loop (search_from_candidates_fixVT) now accumulates 16 neighbours and calls distances_batch_16; the upper-level greedy and unbounded loops keep batch-4.

3. AMX in a dedicated SIMD file, behind the dispatch mechanism
Added SIMDLevel::AMX (above AVX512_SPR) to simd_levels.{h,cpp} and simd_dispatch.h — runtime CPUID + tile-permission detection, fallback chain AMX → AVX512_SPR → AVX512 → AVX2 → NONE. The kernel is faiss/impl/scalar_quantizer/sq-amx.cpp, guarded by COMPILE_SIMD_AMX and built into a faiss_amx target with -mamx-tile -mamx-bf16 (mirroring faiss_avx512_spr). ScalarQuantizer::get_distance_computer dispatches through with_simd_level_amx.

Scope — inner product only. AMX is routed only for QT_bf16 + METRIC_INNER_PRODUCT. For L2, the per-vector ‖c‖² needed to turn the tile dot product into a distance offsets the tile gain (and slightly hurts recall by quantising the query), so L2 and every other quantizer type use the AVX512 SQ path.

Performance (Granite Rapids, Xeon 6972P; d=768, 100k vectors, efSearch=128; AMX vs the AVX512 bf16 SQ path on an identical index, recall unchanged at ~0.99):

threads AVX512 QPS AMX QPS speedup
1 1,386 1,831 1.32×
16 (8C/16T) 15,008 18,217 1.21×

At full-socket thread counts the workload becomes memory-bandwidth-bound and the two converge, as expected — the gain is in the compute-bound regime (latency / moderate concurrency).

The branch is rebased onto current main and is now three focused commits (AMX SIMD level / distances_batch_16 / sq-amx.cpp). I'll update the PR title and description to match the new scope.

@xtangxtang xtangxtang changed the title feat: Add BF16+AMX acceleration for HNSW search (batch-16 distance computation) feat: AMX-BF16 inner-product distance computer for IndexHNSWSQ Jun 5, 2026
xtangxtang pushed a commit to epeshared/faiss-hnsw-amx that referenced this pull request Jun 11, 2026
…kresearch#5235 base

knn main's jni code depends on faiss-side changes from its patches
0007/0008/0010 (and parts of 0005) that are not in upstream June-2026:
- IndexBinary/IndexBinaryHNSW: optional isExtendedIndex ctor flag relaxing
  the d % 8 check; virtual get_distance_computer (FaissSQHnsw overrides it)
- write_index_binary with io_flags (IO_FLAG_SKIP_STORAGE) + binary HNSW
  skip-storage write path
- HNSW::m member; multi-vector search_level_0 pieces
0009's entry-point radial-search machinery is NOT ported: June-2026 already
ships IndexHNSWCagra::range_search and nothing in knn jni references the
*_with_entry_point variants.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants