[WIP] Sdpa opt flashattn x86#6696
Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c7c7d1da49
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| #pragma omp parallel for num_threads(opt.num_threads) | ||
| for (int g = 0; g < num_group; g++) |
There was a problem hiding this comment.
Parallelize decode work across attention heads
The decode fast path is only parallelized over num_group, then iterates num_heads_per_group serially inside each task. In common MQA/GQA cases where num_group is 1 or very small and num_heads is large, this launches too few OpenMP iterations and leaves most CPU cores idle, causing a large per-token latency regression versus head-level parallel execution.
Useful? React with 👍 / 👎.
| #if NCNN_BF16 | ||
| support_bf16_storage = false; | ||
| #endif |
There was a problem hiding this comment.
Preserve BF16 storage support in SDPA_x86
This constructor now unconditionally disables BF16 storage support, so non-int8 SDPA no longer advertises BF16 capability even when NCNN_BF16 is enabled. That forces additional FP32 conversion/storage around every SDPA call in BF16 inference pipelines, which is a regression in memory bandwidth and latency (and can trigger OOM in long-context workloads) compared with the prior BF16-capable path.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 274f6d5cf0
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| const float* mask_data[num_heads_per_group]; | ||
| int mask_stride[num_heads_per_group]; |
There was a problem hiding this comment.
Replace variable-length arrays in SDPA prefill state
The prefill path declares mask_data[num_heads_per_group] and mask_stride[num_heads_per_group] as variable-length arrays, which are a GCC/Clang extension and not valid C++ for MSVC. I checked the repository workflows (.github/workflows/windows.yml and windows-xp.yml), and they compile with Visual Studio toolchains, so this change can break Windows CI/builds even though it compiles on GNU toolchains. Please switch these to a standard container (for example, std::vector) or another MSVC-compatible allocation strategy.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 1473d633b8
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| for (; k < d; k++) | ||
| { | ||
| float qv = Q[(i + mi) * d + k]; | ||
| sum0 += qv * k0[k]; | ||
| sum1 += qv * k1[k]; |
There was a problem hiding this comment.
Reset tail index per row in AVX QK GEMM kernel
In qk_gemm_avx, the scalar cleanup loop for leftover dimensions reuses the same k counter across all mi rows in the micro-kernel. When d is not a multiple of 8, mi=0 advances k to d, so mi>0 skip the tail terms entirely, producing systematically low QK scores for most rows and incorrect attention outputs (reproducible with non-multiple head dims like 26 in test_sdpa).
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: cdd69dd69f
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| for (; k < d; k++) | ||
| sum += Q[(i + mi) * d + k] * kptr[k]; |
There was a problem hiding this comment.
Reset tail index per row in AVX odd-column QK path
In qk_gemm_avx, the j < n branch reuses the same k variable across all six mi rows during scalar tail cleanup. When d is not a multiple of 8 and this odd-column path is hit (n odd), mi=0 advances k to d, so mi=1..5 skip tail terms and produce underestimated QK scores for most rows. This causes incorrect attention outputs for valid shapes (for example, non-multiple head dims like 26 with odd destination sequence length).
Useful? React with 👍 / 👎.
| for (; k < d; k++) | ||
| sum += Q[(i + mi) * d + k] * kptr[k]; |
There was a problem hiding this comment.
Reset tail index per row in SSE2 QK micro-kernel
The SSE2 qk_gemm_sse2 micro-kernel has the same tail-index reuse bug: the scalar cleanup loop uses one shared k across mi=0..3. For d % 4 != 0, the first row consumes the tail and subsequent rows skip it, yielding systematically low dot products and wrong attention results on SSE2-only builds (including CI configs that disable AVX).
Useful? React with 👍 / 👎.
… into sdpa-opt-flashattn-x86
…e helpers, add missing D values to dispatch, add FP32 group_parallel decode path - Extract decode_mask_vec / decode_max_vec / decode_exp_sum_vec shared helpers with AVX512/AVX/SSE2 vectorization, replacing scalar loops in int8 decode path - Refactor sdpa_decode and sdpa_decode_chunk to use shared helpers, eliminating ~150 lines of duplicated mask/softmax code per function - Add D=768,1536,3072 to qk_gemm_dispatch and pv_gemm_dispatch for common LLM head dimensions (LLaMA-13B, Qwen-72B, etc.) - Add group_parallel path to FP32 decode: when num_group < num_threads, parallelize per-head instead of per-group to improve MQA/GQA utilization
…y, hoist Q copy in large_dim path - Apply vectorized max reduction + exp to int8 prefill path, previously scalar - Replace all remaining inline mask code in prefill with decode_mask_vec() - Move Q copy outside N-tile loop in large_dim prefill path: copies per M-tile reduce from (dst_seqlen/BLOCK_N) * num_heads to num_heads Key improvements: prefill e4096_s64_int8: -3.6% decode e128_p2048_int8: -10.4% decode e512_p1024_fp16ps: -8.1%
…, add V prefetch - Replace int8 prefill scalar scale_factor application with vec_scale_dispatch - Replace int8 prefill scalar output normalization with inline AVX512/AVX/SSE2 - Replace int8 decode scalar output normalization with memcpy+vec_scale_dispatch - Replace int8 prefill scalar zero-init with vec_zero_dispatch - Fix pv_gemm dispatch: use D_UNROLL=128 (AVX512) / 64 (AVX) / 32 (SSE2) instead of matching full D dimension. Prevents massive register spilling for large D (VEC_PER_UNROLL was 256 for D=4096 on AVX512 - far beyond 32 ZMM regs) - Add V matrix software prefetch in pv_gemm_avx512 inner j loop for d>=512 (hardware prefetcher cannot track 16KB+ strides)
… for D>=512 - Prefetch next 4 K rows at start of each j+=4 block when D >= 512 - Applied to both M_BLOCK2 and M_BLOCK1 loops - D>=512 means rows are at least 2KB apart; hardware prefetcher may not track such large strides reliably
…zation When num_group < num_threads (e.g. MQA with num_group=1, num_heads=32), parallelize per-head instead of per-group to use all threads
…s directly Deleted vec_scale_dispatch, vec_zero_dispatch, softmax_tile_dispatch, sdpa_decode_dispatch - each was a simple 1-line forward to the real function with no additional logic. Replaced all 16 call sites with direct calls to vec_scale/vec_zero/softmax_tile/sdpa_decode.
…rward_prefill() forward() was 1234 lines; prefill path (334 lines) is now a standalone static helper with explicit parameter list. forward() calls it via return sdpa_forward_prefill(query_ref, ..., use_bf16_path); No functional change. Benefits: forward() shrinks ~334 lines, prefill logic is independently testable, parameter dependencies are explicit.
forward() already defines BLOCK_N=128 at the top; the inner if-block's redefinition was shadowing it unnecessarily. Now only 1 BLOCK_N per scope: - sdpa_decode(), sdpa_decode_chunk(), sdpa_forward_prefill() each have their own - forward() has a single shared BLOCK_N used by all internal paths
…nes of duplicate code sdpa_decode simply delegates to sdpa_decode_chunk(out, &m, &l, q, K, V, mask, 0, n, d, out_d, scale) then normalizes with 1/l. Also fixes a pre-existing unused-variable warning (qk_num_blocks).
…f duplicate N-loop code The int8 decode group_parallel and per-head paths had identical N-tile inner loops (decode_qk → mask → max → online softmax → pv_gemv). Extracted as sdpa_int8_decode_core(). Also fixed: - sdpa_decode forward-declares sdpa_decode_chunk (now calls it) - sdpa_forward_prefill suppresses unused-param warning for num_heads
…rom forward() Extract 4 top-level static functions to make forward() a thin dispatcher: - sdpa_decode_int8_x86() — INT8 decode (src_seqlen==1) - sdpa_prefill_int8_x86() — INT8 prefill (src_seqlen>1) - sdpa_decode_bf16s_x86() — BF16 decode - sdpa_decode_x86() — FP32 decode Also includes prior refactoring: - Remove redundant *_dispatch wrappers from sdpa_x86_bf16s.h - Unify BF16 decode to reuse generic FP32 helpers (decode_mask_vec, decode_max_vec, etc.) - Normalize BF16 micro-kernel naming with _kernel suffix - Extract INT8 KV quantization into sdpa_quantize_key_value_int8_x86() - Add missing include guard to sdpa_x86_int8.h No functional change; decode performance improved vs baseline.
- Add n>=256 guard to QK/PV GEMM prefetch to avoid small-n overhead - Change large_dim threshold to embed_dim>512 && src_seqlen>16 so that MQA configs with seqlen=16 use the !large_dim batched path - Eliminate Q memcpy in large_dim path by passing query_head.row() directly - Always allocate q_batch with full BLOCK_M*num_heads_per_group size
…qlen - Revert qk_gemm_specialized_avx512<4096> tiling from 8,2 back to 2,2 to fix massive register spilling in !large_dim batched path (acc[8][4] = 32 ZMMs). - Add qk_gemm_specialized_avx512_large_m<4096> (2,2) for large_m path, but switch large_dim D=4096 QK GEMM to generic qk_gemm_avx512 (M=8,N=2) which has better ILP for small m/n. - Change large_dim threshold to embed_dim>512 && src_seqlen>16 so MQA seqlen=16 uses the faster batched !large_dim path. - Add n>=256 guard to QK/PV GEMM prefetch to avoid small-n overhead. - Disable outer OpenMP when num_group*num_m_tiles==1 to avoid parallel region overhead for single-tile configs. - Restructure large_dim path as N-outer/head-inner to reuse K/V tiles across heads, with inner OpenMP over heads when only one outer tile exists. Fixes: groups=4 seqlen=32: ~51ms -> ~8ms (-84%, restored correct s_head striding) groups=1 seqlen=16: ~15ms -> ~5ms (-66%, uses batched path) groups=32 seqlen=32: ~6ms -> ~6ms (unchanged) Remaining gap vs baseline: groups=1 seqlen=32: +63% (8.3ms vs 5.1ms, inherent in post-refactor path) groups=1 seqlen=64: +34% (30.2ms vs 22.6ms)
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f39f525859
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if (kv_cache) | ||
| { | ||
| const Mat value_head = value.channel(g); |
There was a problem hiding this comment.
Quantize value tensors for int8 decode without KV cache
In sdpa_quantize_key_value_int8_x86, value quantization is skipped unless kv_cache is true, but the decode int8 path (src_seqlen == 1) always consumes value_int8/value_scales. For non-cache calls that still hit int8 decode (for example single-token attention with int8_scale_term enabled), the value buffers remain uninitialized and the attention output is computed from invalid data. Please quantize V for decode regardless of kv_cache, or gate int8 decode on kv_cache explicitly.
Useful? React with 👍 / 👎.
| top_blobs, opt, embed_dim, src_seqlen, num_heads, | ||
| num_group, out_embed_dim, dst_seqlen, | ||
| num_heads_per_group, _scale, kv_cache, attn_mask, | ||
| use_bf16_path); |
There was a problem hiding this comment.
Avoid referencing BF16 flag outside its compile guard
use_bf16_path is declared only inside #if NCNN_BF16 but is passed unconditionally to sdpa_forward_prefill at the end of forward. When NCNN_BF16=0, this identifier is undefined and the file does not compile. Add a non-BF16 fallback definition (e.g. false) or guard the call argument with the same compile-time condition.
Useful? React with 👍 / 👎.
SDPA x86 Performance ComparisonBaseline:
Prefill
Decode
|
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 02e8f84466
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| for (int j = 0; j < block_n; j++) | ||
| { | ||
| float p = s[j]; | ||
| float inv_scale = 1.f / vscales[(n_start + j) * num_blocks + vb]; |
There was a problem hiding this comment.
Multiply int8 PV by stored dequant scale
In the int8 path, sdpa_quantize_key_value_int8_x86 converts each value scale to its reciprocal after quantization, so vscales already holds the dequant multiplier (absmax / 127). Dividing by it here applies the inverse scale instead, so any int8 SDPA decode/prefill with quantized values whose block absmax is not 1 produces PV outputs inflated or shrunk by the wrong factor.
Useful? React with 👍 / 👎.
| sum0 += (float)block_sum0 / (qscales[0] * ks0[0]); | ||
| sum1 += (float)block_sum1 / (qscales[0] * ks1[0]); |
There was a problem hiding this comment.
Multiply AVX2 QK dots by dequant scales
For AVX2 int8 decode, the paired-column loop divides by qscales * kscales, but both scales were already inverted to dequant multipliers before this kernel is called (q_scale[0] = 1.f / q_scale[0] and key_scales_head.row(j)[0] = 1.f / ...). When block_n >= 2 on AVX2 builds, the first columns in each pair get attention logits scaled in the opposite direction while the scalar tail below multiplies, yielding inconsistent and incorrect attention scores.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 32c0e6ba9b
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| Mat s_vec_thread = s_vec.channel(get_omp_thread_num()); | ||
| Mat o_accum_thread = o_accum.channel(get_omp_thread_num()); |
There was a problem hiding this comment.
Use rows for 2D per-thread workspaces
When prefill runs with opt.num_threads > 1, s_vec and o_accum are 2D Mats whose per-thread storage is in rows, but channel(get_omp_thread_num()) offsets by cstep; for a 2D Mat cstep is the entire allocation, so any OpenMP worker with thread id > 0 writes past the workspace and can corrupt memory or produce wrong attention outputs. Use row(thread_id) or allocate these as 3D Mats before taking channels.
Useful? React with 👍 / 👎.
|
|
||
| void decode_qk_dot_int8_avxvnniint8(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) | ||
| { | ||
| decode_qk_dot_int8_avxvnni_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); |
There was a problem hiding this comment.
Call an AVXVNNIINT8-defined decode kernel
With NCNN_AVXVNNIINT8 enabled, CMake adds -mavxvnniint8, which defines __AVXVNNIINT8__ but not __AVXVNNI__ on GCC; sdpa_x86_int8.h only defines decode_qk_dot_int8_avxvnni_kernel under #if __AVXVNNI__. This wrapper therefore fails to compile in the AVXVNNIINT8 build before any runtime dispatch can happen.
Useful? React with 👍 / 👎.
No description provided.