
[WIP] Sdpa opt flashattn x86 #6696

Open
futz12 wants to merge 56 commits into Tencent:master from futz12:sdpa-opt-flashattn-x86

@futz12 futz12 commented Apr 28, 2026

Contributor

No description provided.

futz12 commented Apr 28, 2026

Contributor Author


Almost the same as gemm+softmax but with lower memory use, and faster than llama.cpp.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c7c7d1da49

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread src/layer/x86/sdpa_x86.cpp Outdated
Comment on lines +3511 to +3512
#pragma omp parallel for num_threads(opt.num_threads)
for (int g = 0; g < num_group; g++)

P1 Badge Parallelize decode work across attention heads

The decode fast path is only parallelized over num_group, then iterates num_heads_per_group serially inside each task. In common MQA/GQA cases where num_group is 1 or very small and num_heads is large, this launches too few OpenMP iterations and leaves most CPU cores idle, causing a large per-token latency regression versus head-level parallel execution.
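A minimal sketch of the head-level parallelism this comment asks for, not the PR's actual code: the `(group, head)` nest is flattened into one loop index so OpenMP sees `num_group * num_heads_per_group` iterations. `decode_one_head` and `decode_all_heads` are hypothetical names, and the per-head work is replaced by a visit counter.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the per-head decode work; it only records
// that the (group, head) pair was visited.
static void decode_one_head(int g, int h, int num_heads_per_group, std::vector<int>& visited)
{
    visited[g * num_heads_per_group + h] += 1;
}

// Flatten the (group, head) nest into one index so the parallel loop has
// num_group * num_heads_per_group iterations instead of only num_group.
std::vector<int> decode_all_heads(int num_group, int num_heads_per_group, int num_threads)
{
    const int num_heads = num_group * num_heads_per_group;
    std::vector<int> visited(num_heads, 0);

    // Each iteration writes a distinct element, so no synchronization is needed.
    #pragma omp parallel for num_threads(num_threads)
    for (int gh = 0; gh < num_heads; gh++)
    {
        const int g = gh / num_heads_per_group; // KV group that owns this head
        const int h = gh % num_heads_per_group; // head index inside the group
        decode_one_head(g, h, num_heads_per_group, visited);
    }
    return visited;
}
```

With `num_group = 1` and `num_heads_per_group = 32`, the loop exposes 32 iterations rather than one, which is the MQA case the comment describes.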


Comment on lines +32 to +34
#if NCNN_BF16
support_bf16_storage = false;
#endif

P2 Badge Preserve BF16 storage support in SDPA_x86

This constructor now unconditionally disables BF16 storage support, so non-int8 SDPA no longer advertises BF16 capability even when NCNN_BF16 is enabled. That forces additional FP32 conversion/storage around every SDPA call in BF16 inference pipelines, which is a regression in memory bandwidth and latency (and can trigger OOM in long-context workloads) compared with the prior BF16-capable path.


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Reviewed commit: 274f6d5cf0

Comment thread src/layer/x86/sdpa_x86.cpp Outdated
Comment on lines +3560 to +3561
const float* mask_data[num_heads_per_group];
int mask_stride[num_heads_per_group];

P1 Badge Replace variable-length arrays in SDPA prefill state

The prefill path declares mask_data[num_heads_per_group] and mask_stride[num_heads_per_group] as variable-length arrays, which are a GCC/Clang extension and not valid C++ for MSVC. I checked the repository workflows (.github/workflows/windows.yml and windows-xp.yml), and they compile with Visual Studio toolchains, so this change can break Windows CI/builds even though it compiles on GNU toolchains. Please switch these to a standard container (for example, std::vector) or another MSVC-compatible allocation strategy.
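One way to make that change, sketched under assumptions: the two variable-length arrays become `std::vector` members sized at run time. The `HeadMasks` struct, the `gather_head_masks` helper, and the mask layout (one contiguous plane per head) are hypothetical and only illustrate the replacement.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Per-head mask pointers and strides held in standard containers, which
// compile on MSVC as well as on GCC and Clang.
struct HeadMasks
{
    std::vector<const float*> data;
    std::vector<int> stride;
};

// Assumed layout for this example: one src_seqlen x dst_seqlen plane per head.
HeadMasks gather_head_masks(const float* mask_base, int num_heads_per_group, int src_seqlen, int dst_seqlen)
{
    HeadMasks m;
    m.data.resize(num_heads_per_group);   // replaces: const float* mask_data[num_heads_per_group];
    m.stride.resize(num_heads_per_group); // replaces: int mask_stride[num_heads_per_group];
    for (int h = 0; h < num_heads_per_group; h++)
    {
        m.data[h] = mask_base + (std::size_t)h * src_seqlen * dst_seqlen;
        m.stride[h] = dst_seqlen;
    }
    return m;
}
```

If the heap allocation matters on a hot path, a fixed-capacity array sized by a compile-time upper bound on heads per group would be another MSVC-compatible option.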


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Reviewed commit: 1473d633b8

Comment on lines +1874 to +1878
for (; k < d; k++)
{
float qv = Q[(i + mi) * d + k];
sum0 += qv * k0[k];
sum1 += qv * k1[k];

P1 Badge Reset tail index per row in AVX QK GEMM kernel

In qk_gemm_avx, the scalar cleanup loop for leftover dimensions reuses the same k counter across all mi rows in the micro-kernel. When d is not a multiple of 8, mi=0 advances k to d, so mi>0 skip the tail terms entirely, producing systematically low QK scores for most rows and incorrect attention outputs (reproducible with non-multiple head dims like 26 in test_sdpa).
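The failure mode can be reproduced with a scalar model of the micro-kernel. This is a sketch, not the PR's kernel: `VEC` stands in for the AVX width, `MR` for the number of rows per call, and all function names are hypothetical.

```cpp
#include <cassert>

static const int VEC = 8; // stands in for the AVX width in floats
static const int MR = 2;  // rows handled per micro-kernel call (illustrative)

// Shared "vector" body: every row accumulates the first (d / VEC) * VEC terms.
// Returns the index where the scalar tail starts.
static int vector_body(const float* Q, const float* K, int d, float* sum)
{
    int k = 0;
    for (; k + VEC <= d; k += VEC)
        for (int mi = 0; mi < MR; mi++)
            for (int t = 0; t < VEC; t++)
                sum[mi] += Q[mi * d + k + t] * K[k + t];
    return k;
}

// Buggy shape: one k is shared by all rows, so after row 0 finishes its
// tail k == d and every later row skips its tail terms.
void qk_rows_shared_tail(const float* Q, const float* K, int d, float* sum)
{
    for (int mi = 0; mi < MR; mi++) sum[mi] = 0.f;
    int k = vector_body(Q, K, d, sum);
    for (int mi = 0; mi < MR; mi++)
        for (; k < d; k++)
            sum[mi] += Q[mi * d + k] * K[k];
}

// Fixed shape: remember where the tail starts and restart from it per row.
void qk_rows_reset_tail(const float* Q, const float* K, int d, float* sum)
{
    for (int mi = 0; mi < MR; mi++) sum[mi] = 0.f;
    const int k_tail = vector_body(Q, K, d, sum);
    for (int mi = 0; mi < MR; mi++)
        for (int k = k_tail; k < d; k++)
            sum[mi] += Q[mi * d + k] * K[k];
}
```

With `d = 10` and all inputs set to 1, the shared-index version yields 10 for row 0 but only 8 for row 1, while the reset version yields 10 for both rows.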


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Reviewed commit: cdd69dd69f

Comment on lines +1940 to +1941
for (; k < d; k++)
sum += Q[(i + mi) * d + k] * kptr[k];

P1 Badge Reset tail index per row in AVX odd-column QK path

In qk_gemm_avx, the j < n branch reuses the same k variable across all six mi rows during scalar tail cleanup. When d is not a multiple of 8 and this odd-column path is hit (n odd), mi=0 advances k to d, so mi=1..5 skip tail terms and produce underestimated QK scores for most rows. This causes incorrect attention outputs for valid shapes (for example, non-multiple head dims like 26 with odd destination sequence length).


Comment on lines +2593 to +2594
for (; k < d; k++)
sum += Q[(i + mi) * d + k] * kptr[k];

P1 Badge Reset tail index per row in SSE2 QK micro-kernel

The SSE2 qk_gemm_sse2 micro-kernel has the same tail-index reuse bug: the scalar cleanup loop uses one shared k across mi=0..3. For d % 4 != 0, the first row consumes the tail and subsequent rows skip it, yielding systematically low dot products and wrong attention results on SSE2-only builds (including CI configs that disable AVX).


futz12 and others added 16 commits May 3, 2026 22:23
…e helpers, add missing D values to dispatch, add FP32 group_parallel decode path

- Extract decode_mask_vec / decode_max_vec / decode_exp_sum_vec shared helpers
  with AVX512/AVX/SSE2 vectorization, replacing scalar loops in int8 decode path
- Refactor sdpa_decode and sdpa_decode_chunk to use shared helpers, eliminating
  ~150 lines of duplicated mask/softmax code per function
- Add D=768,1536,3072 to qk_gemm_dispatch and pv_gemm_dispatch for common
  LLM head dimensions (LLaMA-13B, Qwen-72B, etc.)
- Add group_parallel path to FP32 decode: when num_group < num_threads,
  parallelize per-head instead of per-group to improve MQA/GQA utilization
…y, hoist Q copy in large_dim path

- Apply vectorized max reduction + exp to int8 prefill path, previously scalar
- Replace all remaining inline mask code in prefill with decode_mask_vec()
- Move Q copy outside N-tile loop in large_dim prefill path:
  reduces copies per M-tile from (dst_seqlen/BLOCK_N) * num_heads to num_heads

Key improvements:
  prefill e4096_s64_int8: -3.6%
  decode  e128_p2048_int8: -10.4%
  decode  e512_p1024_fp16ps: -8.1%
…, add V prefetch

- Replace int8 prefill scalar scale_factor application with vec_scale_dispatch
- Replace int8 prefill scalar output normalization with inline AVX512/AVX/SSE2
- Replace int8 decode scalar output normalization with memcpy+vec_scale_dispatch
- Replace int8 prefill scalar zero-init with vec_zero_dispatch
- Fix pv_gemm dispatch: use D_UNROLL=128 (AVX512) / 64 (AVX) / 32 (SSE2)
  instead of matching full D dimension. Prevents massive register spilling
  for large D (VEC_PER_UNROLL was 256 for D=4096 on AVX512 - far beyond 32 ZMM regs)
- Add V matrix software prefetch in pv_gemm_avx512 inner j loop for d>=512
  (hardware prefetcher cannot track 16KB+ strides)
… for D>=512

- Prefetch next 4 K rows at start of each j+=4 block when D >= 512
- Applied to both M_BLOCK2 and M_BLOCK1 loops
- D>=512 means rows are at least 2KB apart; hardware prefetcher may
  not track such large strides reliably
…zation

When num_group < num_threads (e.g. MQA with num_group=1, num_heads=32),
parallelize per-head instead of per-group to use all threads
…s directly

Deleted vec_scale_dispatch, vec_zero_dispatch, softmax_tile_dispatch,
sdpa_decode_dispatch - each was a simple 1-line forward to the real
function with no additional logic. Replaced all 16 call sites with
direct calls to vec_scale/vec_zero/softmax_tile/sdpa_decode.
…rward_prefill()

forward() was 1234 lines; prefill path (334 lines) is now a standalone
static helper with explicit parameter list. forward() calls it via
  return sdpa_forward_prefill(query_ref, ..., use_bf16_path);

No functional change. Benefits: forward() shrinks ~334 lines, prefill
logic is independently testable, parameter dependencies are explicit.
forward() already defines BLOCK_N=128 at the top; the inner if-block's
redefinition was shadowing it unnecessarily. Now only 1 BLOCK_N per scope:
  - sdpa_decode(), sdpa_decode_chunk(), sdpa_forward_prefill() each have their own
  - forward() has a single shared BLOCK_N used by all internal paths
…nes of duplicate code

sdpa_decode simply delegates to sdpa_decode_chunk(out, &m, &l, q, K, V,
mask, 0, n, d, out_d, scale) then normalizes with 1/l.  Also fixes a
pre-existing unused-variable warning (qk_num_blocks).
…f duplicate N-loop code

The int8 decode group_parallel and per-head paths had identical N-tile
inner loops (decode_qk → mask → max → online softmax → pv_gemv).
Extracted as sdpa_int8_decode_core().  Also fixed:
- sdpa_decode forward-declares sdpa_decode_chunk (now calls it)
- sdpa_forward_prefill suppresses unused-param warning for num_heads
…rom forward()

Extract 4 top-level static functions to make forward() a thin dispatcher:
- sdpa_decode_int8_x86()     — INT8 decode (src_seqlen==1)
- sdpa_prefill_int8_x86()    — INT8 prefill (src_seqlen>1)
- sdpa_decode_bf16s_x86()    — BF16 decode
- sdpa_decode_x86()          — FP32 decode

Also includes prior refactoring:
- Remove redundant *_dispatch wrappers from sdpa_x86_bf16s.h
- Unify BF16 decode to reuse generic FP32 helpers (decode_mask_vec, decode_max_vec, etc.)
- Normalize BF16 micro-kernel naming with _kernel suffix
- Extract INT8 KV quantization into sdpa_quantize_key_value_int8_x86()
- Add missing include guard to sdpa_x86_int8.h

No functional change; decode performance improved vs baseline.
- Add n>=256 guard to QK/PV GEMM prefetch to avoid small-n overhead
- Change large_dim threshold to embed_dim>512 && src_seqlen>16 so that
  MQA configs with seqlen=16 use the !large_dim batched path
- Eliminate Q memcpy in large_dim path by passing query_head.row() directly
- Always allocate q_batch with full BLOCK_M*num_heads_per_group size
…qlen

- Revert qk_gemm_specialized_avx512<4096> tiling from 8,2 back to 2,2 to fix
  massive register spilling in !large_dim batched path (acc[8][4] = 32 ZMMs).
- Add qk_gemm_specialized_avx512_large_m<4096> (2,2) for large_m path, but
  switch large_dim D=4096 QK GEMM to generic qk_gemm_avx512 (M=8,N=2) which
  has better ILP for small m/n.
- Change large_dim threshold to embed_dim>512 && src_seqlen>16 so MQA seqlen=16
  uses the faster batched !large_dim path.
- Add n>=256 guard to QK/PV GEMM prefetch to avoid small-n overhead.
- Disable outer OpenMP when num_group*num_m_tiles==1 to avoid parallel region
  overhead for single-tile configs.
- Restructure large_dim path as N-outer/head-inner to reuse K/V tiles across
  heads, with inner OpenMP over heads when only one outer tile exists.

Fixes:
  groups=4 seqlen=32: ~51ms -> ~8ms (-84%, restored correct s_head striding)
  groups=1 seqlen=16: ~15ms -> ~5ms (-66%, uses batched path)
  groups=32 seqlen=32: ~6ms -> ~6ms (unchanged)

Remaining gap vs baseline:
  groups=1 seqlen=32: +63% (8.3ms vs 5.1ms, inherent in post-refactor path)
  groups=1 seqlen=64: +34% (30.2ms vs 22.6ms)
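
One of the commits above describes `sdpa_decode` as delegating to `sdpa_decode_chunk(out, &m, &l, ...)` and then normalizing with `1/l`. The scalar sketch below illustrates that pattern (online softmax with a running max `m` and running sum `l`). The names `decode_chunk` and `decode`, the signatures, and the absence of a mask are assumptions for this example, not the PR's code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Process keys [n_start, n_end) for one query, updating the running max m,
// the running sum l, and the unnormalized output accumulator out.
// No mask is applied here, so every score is finite.
void decode_chunk(float* out, float* m, float* l, const float* q, const float* K, const float* V,
                  int n_start, int n_end, int d, int out_d, float scale)
{
    for (int j = n_start; j < n_end; j++)
    {
        float s = 0.f;
        for (int k = 0; k < d; k++) s += q[k] * K[j * d + k];
        s *= scale;

        const float m_new = std::max(*m, s);
        const float correction = std::exp(*m - m_new); // rescales old state; 0 on the first key
        const float p = std::exp(s - m_new);

        for (int c = 0; c < out_d; c++)
            out[c] = out[c] * correction + p * V[j * out_d + c];
        *l = *l * correction + p;
        *m = m_new;
    }
}

// Whole-sequence decode: one chunk covering [0, n), then normalize by 1/l.
// Assumes n > 0 so that l is nonzero.
std::vector<float> decode(const float* q, const float* K, const float* V, int n, int d, int out_d, float scale)
{
    std::vector<float> out(out_d, 0.f);
    float m = -std::numeric_limits<float>::infinity();
    float l = 0.f;
    decode_chunk(out.data(), &m, &l, q, K, V, 0, n, d, out_d, scale);
    const float inv_l = 1.f / l;
    for (int c = 0; c < out_d; c++) out[c] *= inv_l;
    return out;
}
```

Because the state `(out, m, l)` carries over between calls, splitting the key range into several chunks gives the same result as one call over the full range, which is what makes the delegation possible.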

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Reviewed commit: f39f525859

Comment on lines +4592 to +4594
if (kv_cache)
{
const Mat value_head = value.channel(g);

P1 Badge Quantize value tensors for int8 decode without KV cache

In sdpa_quantize_key_value_int8_x86, value quantization is skipped unless kv_cache is true, but the decode int8 path (src_seqlen == 1) always consumes value_int8/value_scales. For non-cache calls that still hit int8 decode (for example single-token attention with int8_scale_term enabled), the value buffers remain uninitialized and the attention output is computed from invalid data. Please quantize V for decode regardless of kv_cache, or gate int8 decode on kv_cache explicitly.


top_blobs, opt, embed_dim, src_seqlen, num_heads,
num_group, out_embed_dim, dst_seqlen,
num_heads_per_group, _scale, kv_cache, attn_mask,
use_bf16_path);

P1 Badge Avoid referencing BF16 flag outside its compile guard

use_bf16_path is declared only inside #if NCNN_BF16 but is passed unconditionally to sdpa_forward_prefill at the end of forward. When NCNN_BF16=0, this identifier is undefined and the file does not compile. Add a non-BF16 fallback definition (e.g. false) or guard the call argument with the same compile-time condition.
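A sketch of the first option, assuming nothing about the real `forward()` beyond what the comment states: the flag is declared unconditionally with a `false` default and only assigned inside the guard. The `prefill` and `forward` functions here are hypothetical stand-ins, and the `#define` default exists only to make the example self-contained.

```cpp
#include <cassert>

#ifndef NCNN_BF16
#define NCNN_BF16 0 // assumption for this example: build with BF16 disabled
#endif

// Hypothetical stand-in for the prefill helper; returns which path it took.
static int prefill(bool use_bf16_path)
{
    return use_bf16_path ? 1 : 0;
}

int forward(bool bf16_storage_requested)
{
    // Declared outside the guard so the identifier exists in every build.
    bool use_bf16_path = false;
#if NCNN_BF16
    use_bf16_path = bf16_storage_requested;
#else
    (void)bf16_storage_requested; // unused when BF16 is compiled out
#endif
    return prefill(use_bf16_path);
}
```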


futz12 commented May 5, 2026

Contributor Author

SDPA x86 Performance Comparison

Baseline: ~/baseline.txt (2026-04-28) vs Current: sdpa-opt-flashattn-x86 HEAD
All times are per-iteration (μs), calculated by dividing reported min by inner loop count.

  • 🟢 Improvement ≥ 5%
  • 🔴 Regression ≥ 5%
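
For reference, the per-iteration time and the Diff column can be reproduced as follows. This is a sketch of the arithmetic described above; the function names are hypothetical, and the thresholds are the 5% cut-offs from the legend.

```cpp
#include <cassert>
#include <cmath>
#include <string>

// Per-iteration time: reported minimum divided by the inner loop count.
double per_iteration_us(double reported_min_us, int inner_loops)
{
    return reported_min_us / inner_loops;
}

// Relative change of the current time against the baseline, in percent.
// Negative values mean the current build is faster.
double diff_percent(double baseline_us, double current_us)
{
    return (current_us - baseline_us) / baseline_us * 100.0;
}

// Apply the legend: at least 5% faster is an improvement,
// at least 5% slower is a regression, anything else is unmarked.
std::string classify(double diff)
{
    if (diff <= -5.0) return "improvement";
    if (diff >= 5.0) return "regression";
    return "neutral";
}
```

For example, the first prefill row (14.67 μs baseline, 9.51 μs current) gives about -35.2% and is classified as an improvement.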

Prefill

| Config | dtype | Baseline (μs) | Current (μs) | Diff |
|---|---|---|---|---|
| e128 h4 g4 s16 | fp32 | 14.67 | 9.51 | -35.2% 🟢 |
| | fp16ps | 10.29 | 6.76 | -34.3% 🟢 |
| | fp16psa | 10.34 | 6.72 | -35.0% 🟢 |
| | bf16ps | 11.23 | 7.05 | -37.2% 🟢 |
| e128 h4 g4 s32 | fp32 | 23.39 | 21.20 | -9.4% 🟢 |
| | fp16ps | 23.39 | 21.21 | -9.3% 🟢 |
| | fp16psa | 23.39 | 21.20 | -9.4% 🟢 |
| | bf16ps | 23.27 | 28.19 | +21.1% 🔴 |
| e128 h4 g4 s64 | fp32 | 82.50 | 73.10 | -11.4% 🟢 |
| | fp16ps | 83.00 | 73.00 | -12.0% 🟢 |
| | fp16psa | 83.60 | 73.00 | -12.7% 🟢 |
| | bf16ps | 90.00 | 87.00 | -3.3% |
| e128 h4 g4 s128 | fp32 | 306.40 | 283.00 | -7.6% 🟢 |
| | fp16ps | 307.90 | 283.10 | -8.1% 🟢 |
| | fp16psa | 308.60 | 282.90 | -8.3% 🟢 |
| | bf16ps | 295.80 | 290.50 | -1.8% |
| e128 h4 g4 s256 | fp32 | 968.00 | 1109.00 | +14.6% 🔴 |
| | fp16ps | 968.00 | 1111.00 | +14.8% 🔴 |
| | fp16psa | 968.00 | 1111.00 | +14.8% 🔴 |
| | bf16ps | 1023.00 | 1073.00 | +4.9% |
| e128 h4 g4 s512 | fp32 | 3831.00 | 4401.00 | +14.9% 🔴 |
| | fp16ps | 3854.00 | 4424.00 | +14.8% 🔴 |
| | fp16psa | 3826.00 | 4431.00 | +15.8% 🔴 |
| | bf16ps | 3050.00 | 4096.00 | +34.3% 🔴 |
| e512 h8 g8 s16 | fp32 | 50.60 | 42.48 | -16.0% 🟢 |
| | fp16ps | 50.50 | 42.47 | -15.9% 🟢 |
| | fp16psa | 50.50 | 42.11 | -16.6% 🟢 |
| | bf16ps | 55.80 | 39.29 | -29.6% 🟢 |
| e512 h8 g8 s32 | fp32 | 148.30 | 146.50 | -1.2% |
| | fp16ps | 148.40 | 147.10 | -0.9% |
| | fp16psa | 148.30 | 147.50 | -0.5% |
| | bf16ps | 149.40 | 133.60 | -10.6% 🟢 |
| e512 h8 g8 s64 | fp32 | 501.00 | 531.00 | +6.0% 🔴 |
| | fp16ps | 501.00 | 539.00 | +7.6% 🔴 |
| | fp16psa | 501.00 | 539.00 | +7.6% 🔴 |
| | bf16ps | 451.00 | 687.00 | +52.3% 🔴 |
| e512 h8 g8 s128 | fp32 | 1864.00 | 2058.00 | +10.4% 🔴 |
| | fp16ps | 1868.00 | 2065.00 | +10.5% 🔴 |
| | fp16psa | 1870.00 | 2047.00 | +9.5% 🔴 |
| | bf16ps | 1557.00 | 2196.00 | +41.0% 🔴 |
| e512 h8 g8 s256 | fp32 | 7430.00 | 8220.00 | +10.6% 🔴 |
| | fp16ps | 7430.00 | 8220.00 | +10.6% 🔴 |
| | fp16psa | 7430.00 | 8210.00 | +10.5% 🔴 |
| | bf16ps | 7700.00 | 7980.00 | +3.6% |
| e512 h8 g8 s512 | fp32 | 32240.00 | 32640.00 | +1.2% |
| | fp16ps | 32200.00 | 31960.00 | -0.7% |
| | fp16psa | 32890.00 | 32990.00 | +0.3% |
| | bf16ps | 28240.00 | 29740.00 | +5.3% 🔴 |
| e512 h8 g8 s1024 | fp32 | 121340.00 | 128440.00 | +5.9% 🔴 |
| | fp16ps | 121270.00 | 128150.00 | +5.7% 🔴 |
| | fp16psa | 121350.00 | 128180.00 | +5.6% 🔴 |
| | bf16ps | 102310.00 | 110010.00 | +7.5% 🔴 |
| e4096 h32 g1 s16 | fp32 | 1506.00 | 4461.00 | +196.2% 🔴 |
| | fp16ps | 1512.00 | 4471.00 | +195.7% 🔴 |
| | fp16psa | 1521.00 | 4463.00 | +193.4% 🔴 |
| | bf16ps | 1792.00 | 4836.00 | +169.9% 🔴 |
| e4096 h32 g1 s32 | fp32 | 5020.00 | 7090.00 | +41.2% 🔴 |
| | fp16ps | 5030.00 | 7100.00 | +41.2% 🔴 |
| | fp16psa | 5020.00 | 7120.00 | +41.8% 🔴 |
| | bf16ps | 4990.00 | 11400.00 | +128.5% 🔴 |
| e4096 h32 g1 s64 | fp32 | 22320.00 | 26020.00 | +16.6% 🔴 |
| | fp16ps | 22360.00 | 26030.00 | +16.4% 🔴 |
| | fp16psa | 22350.00 | 26010.00 | +16.4% 🔴 |
| | bf16ps | 21110.00 | 28320.00 | +34.2% 🔴 |
| e4096 h32 g4 s16 | fp32 | 1547.00 | 1584.00 | +2.4% |
| | fp16ps | 1553.00 | 1590.00 | +2.4% |
| | fp16psa | 1550.00 | 1584.00 | +2.2% |
| | bf16ps | 1791.00 | 4047.00 | +126.0% 🔴 |
| e4096 h32 g4 s32 | fp32 | 4966.00 | 6980.00 | +40.6% 🔴 |
| | fp16ps | 4959.00 | 7000.00 | +41.2% 🔴 |
| | fp16psa | 5082.00 | 6970.00 | +37.2% 🔴 |
| | bf16ps | 5042.00 | 7010.00 | +39.0% 🔴 |
| e4096 h32 g4 s64 | fp32 | 22180.00 | 20990.00 | -5.4% 🟢 |
| | fp16ps | 22220.00 | 20860.00 | -6.1% 🟢 |
| | fp16psa | 22160.00 | 21040.00 | -5.1% 🟢 |
| | bf16ps | 21850.00 | 23510.00 | +7.6% 🔴 |
| e4096 h32 g32 s16 | fp32 | 2084.00 | 1854.00 | -11.0% 🟢 |
| | fp16ps | 2082.00 | 1854.00 | -11.0% 🟢 |
| | fp16psa | 2084.00 | 1849.00 | -11.3% 🟢 |
| | bf16ps | 2245.00 | 1758.00 | -21.7% 🟢 |
| e4096 h32 g32 s32 | fp32 | 5990.00 | 5750.00 | -4.0% |
| | fp16ps | 5990.00 | 5740.00 | -4.2% |
| | fp16psa | 5980.00 | 5750.00 | -3.8% |
| | bf16ps | 11610.00 | 9080.00 | -21.8% 🟢 |
| e4096 h32 g32 s64 | fp32 | 24320.00 | 22570.00 | -7.2% 🟢 |
| | fp16ps | 24490.00 | 23230.00 | -5.1% 🟢 |
| | fp16psa | 24500.00 | 22600.00 | -7.8% 🟢 |
| | bf16ps | 29560.00 | 24060.00 | -18.6% 🟢 |

Decode

| Config | dtype | Baseline (μs) | Current (μs) | Diff |
|---|---|---|---|---|
| e128 h4 g4 p0 | fp32 | 5.46 | 0.44 | -92.0% 🟢 |
| | fp16ps | 5.45 | 0.44 | -92.0% 🟢 |
| | fp16psa | 5.42 | 0.44 | -91.9% 🟢 |
| e128 h4 g4 p128 | fp32 | 75.20 | 47.83 | -36.4% 🟢 |
| | fp16ps | 75.30 | 47.87 | -36.4% 🟢 |
| | fp16psa | 74.60 | 47.81 | -35.9% 🟢 |
| e128 h4 g4 p512 | fp32 | 321.80 | 223.80 | -30.5% 🟢 |
| | fp16ps | 324.20 | 224.10 | -30.9% 🟢 |
| | fp16psa | 324.00 | 224.20 | -30.8% 🟢 |
| e128 h4 g4 p1024 | fp32 | 686.00 | 460.80 | -32.8% 🟢 |
| | fp16ps | 687.00 | 457.50 | -33.4% 🟢 |
| | fp16psa | 687.00 | 457.40 | -33.4% 🟢 |
| e128 h4 g4 p2048 | fp32 | 1854.00 | 1169.00 | -36.9% 🟢 |
| | fp16ps | 1802.00 | 1126.00 | -37.5% 🟢 |
| | fp16psa | 1794.00 | 1137.00 | -36.6% 🟢 |
| e512 h8 g8 p0 | fp32 | 11.38 | 1.13 | -90.1% 🟢 |
| | fp16ps | 11.34 | 1.13 | -90.0% 🟢 |
| | fp16psa | 11.39 | 1.13 | -90.1% 🟢 |
| e512 h8 g8 p128 | fp32 | 206.50 | 94.50 | -54.2% 🟢 |
| | fp16ps | 184.50 | 95.60 | -48.2% 🟢 |
| | fp16psa | 187.50 | 95.40 | -49.1% 🟢 |
| e512 h8 g8 p512 | fp32 | 4022.00 | 2854.00 | -29.0% 🟢 |
| | fp16ps | 3962.00 | 2903.00 | -26.7% 🟢 |
| | fp16psa | 3998.00 | 2906.00 | -27.3% 🟢 |
| e512 h8 g8 p1024 | fp32 | 8910.00 | 6070.00 | -31.9% 🟢 |
| | fp16ps | 8920.00 | 6020.00 | -32.5% 🟢 |
| | fp16psa | 8730.00 | 6080.00 | -30.4% 🟢 |
| e512 h8 g8 p2048 | fp32 | 17750.00 | 12580.00 | -29.1% 🟢 |
| | fp16ps | 17650.00 | 12640.00 | -28.4% 🟢 |
| | fp16psa | 17810.00 | 12630.00 | -29.1% 🟢 |
| e4096 h32 g32 p0 | fp32 | 69.20 | 24.13 | -65.1% 🟢 |
| | fp16ps | 69.40 | 24.25 | -65.1% 🟢 |
| | fp16psa | 69.50 | 24.11 | -65.3% 🟢 |

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Reviewed commit: 02e8f84466

for (int j = 0; j < block_n; j++)
{
float p = s[j];
float inv_scale = 1.f / vscales[(n_start + j) * num_blocks + vb];

P1 Badge Multiply int8 PV by stored dequant scale

In the int8 path, sdpa_quantize_key_value_int8_x86 converts each value scale to its reciprocal after quantization, so vscales already holds the dequant multiplier (absmax / 127). Dividing by it here applies the inverse scale instead, so any int8 SDPA decode/prefill with quantized values whose block absmax is not 1 produces PV outputs inflated or shrunk by the wrong factor.
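The scale convention can be checked with a small round trip. This sketch assumes symmetric per-block int8 quantization with scale `127 / absmax`, as the comment describes; `quantize_block` and `dequantize` are hypothetical helpers, not functions from the PR.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Quantize one block to int8 and return the dequant multiplier absmax / 127,
// i.e. the reciprocal of the quantization scale 127 / absmax.
float quantize_block(const float* x, int n, signed char* q)
{
    float absmax = 0.f;
    for (int i = 0; i < n; i++) absmax = std::max(absmax, std::fabs(x[i]));
    if (absmax == 0.f)
    {
        for (int i = 0; i < n; i++) q[i] = 0;
        return 0.f; // an all-zero block dequantizes to zero
    }
    const float scale = 127.f / absmax;
    for (int i = 0; i < n; i++) q[i] = (signed char)std::lround(x[i] * scale);
    return 1.f / scale; // stored form: already inverted
}

// The stored value is a multiplier, so dequantization multiplies by it.
float dequantize(signed char q, float dequant_scale)
{
    return (float)q * dequant_scale;
}
```

For a block with absmax 4, the stored multiplier is 4/127, so multiplying -127 by it recovers -4, while dividing by it gives roughly -4032, the kind of mis-scaled output the comment warns about.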


Comment on lines +742 to +743
sum0 += (float)block_sum0 / (qscales[0] * ks0[0]);
sum1 += (float)block_sum1 / (qscales[0] * ks1[0]);

P1 Badge Multiply AVX2 QK dots by dequant scales

For AVX2 int8 decode, the paired-column loop divides by qscales * kscales, but both scales were already inverted to dequant multipliers before this kernel is called (q_scale[0] = 1.f / q_scale[0] and key_scales_head.row(j)[0] = 1.f / ...). When block_n >= 2 on AVX2 builds, the first columns in each pair get attention logits scaled in the opposite direction while the scalar tail below multiplies, yielding inconsistent and incorrect attention scores.


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Reviewed commit: 32c0e6ba9b

Comment on lines +3476 to +3477
Mat s_vec_thread = s_vec.channel(get_omp_thread_num());
Mat o_accum_thread = o_accum.channel(get_omp_thread_num());

P1 Badge Use rows for 2D per-thread workspaces

When prefill runs with opt.num_threads > 1, s_vec and o_accum are 2D Mats whose per-thread storage is in rows, but channel(get_omp_thread_num()) offsets by cstep; for a 2D Mat cstep is the entire allocation, so any OpenMP worker with thread id > 0 writes past the workspace and can corrupt memory or produce wrong attention outputs. Use row(thread_id) or allocate these as 3D Mats before taking channels.
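The offset arithmetic behind this can be modelled without ncnn. The struct below is a simplified, hypothetical stand-in for a 2D buffer, not `ncnn::Mat` itself: it ignores alignment padding and only mirrors the fact stated above, that the channel step of a 2D buffer spans the whole allocation.

```cpp
#include <cassert>
#include <cstddef>

// Simplified model of a 2D workspace with h rows of w floats.
struct Workspace2D
{
    int w;
    int h;

    std::size_t total() const { return (std::size_t)w * h; }
    // For a 2D buffer one "channel" covers the whole allocation.
    std::size_t cstep() const { return total(); }

    std::size_t row_offset(int y) const { return (std::size_t)w * y; }
    std::size_t channel_offset(int c) const { return cstep() * c; }
};

// True if a full row of w floats starting at offset fits in the buffer.
bool in_bounds(const Workspace2D& ws, std::size_t offset)
{
    return offset + (std::size_t)ws.w <= ws.total();
}
```

With four per-thread rows, `row_offset(3)` stays inside the buffer, while `channel_offset(1)` already points one full allocation past its start.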



void decode_qk_dot_int8_avxvnniint8(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale)
{
decode_qk_dot_int8_avxvnni_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale);

P1 Badge Call an AVXVNNIINT8-defined decode kernel

With NCNN_AVXVNNIINT8 enabled, CMake adds -mavxvnniint8, which defines __AVXVNNIINT8__ but not __AVXVNNI__ on GCC; sdpa_x86_int8.h only defines decode_qk_dot_int8_avxvnni_kernel under #if __AVXVNNI__. This wrapper therefore fails to compile in the AVXVNNIINT8 build before any runtime dispatch can happen.

