diff --git a/src/layer/arm/sdpa_arm.cpp b/src/layer/arm/sdpa_arm.cpp index 28ad86ae9670..24b8b7196bca 100644 --- a/src/layer/arm/sdpa_arm.cpp +++ b/src/layer/arm/sdpa_arm.cpp @@ -8,6 +8,13 @@ namespace ncnn { +static Mat make_persistent_kvcache_view_arm(const Mat& cache, int seqlen) +{ + Mat view = cache; + view.h = seqlen; + return view; +} + SDPA_arm::SDPA_arm() { #if NCNN_ARM82 @@ -151,14 +158,19 @@ int SDPA_arm::forward(const std::vector& bottom_blobs, std::vector& to if (int8_scale_term) { opt.use_packing_layout = false; // TODO enable packing + if (kv_cache == 2) + return SDPA::forward(bottom_blobs, top_blobs, opt); } const Mat& query = bottom_blobs[0]; const Mat& cur_key = bottom_blobs[1]; const Mat& cur_value = bottom_blobs[2]; const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat(); - const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat(); - const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat(); + const int blob_offset = attn_mask ? 4 : 3; + if (kv_cache == 2 && (int)bottom_blobs.size() <= blob_offset + 1) + return -1; + const Mat& past_key = kv_cache ? bottom_blobs[blob_offset] : Mat(); + const Mat& past_value = kv_cache ? bottom_blobs[blob_offset + 1] : Mat(); const int embed_dim = query.w; const int src_seqlen = query.h; @@ -166,13 +178,63 @@ int SDPA_arm::forward(const std::vector& bottom_blobs, std::vector& to const int cur_seqlen = cur_key.h; const int num_group = cur_key.c; const int out_embed_dim = cur_value.w; - const int past_seqlen = kv_cache ? past_key.h : 0; + + int past_seqlen = 0; + if (kv_cache == 2) + { + if (past_key.dims == 0 || past_value.dims == 0) + return -1; + past_seqlen = past_key.h; + } + else if (kv_cache == 1 && past_key.dims > 0) + past_seqlen = past_key.h; + + if (kv_cache == 2 && past_value.h != past_seqlen) + return -1; + const int dst_seqlen = past_seqlen + cur_seqlen; const size_t elemsize = query.elemsize; + const int num_heads_per_group = num_heads / num_group; + + // kv_cache==2: in-place append to preallocated buffer + if (kv_cache == 2 && past_key.dims > 0) + { + const int key_capacity = (int)(past_key.cstep / embed_dim); + const int value_capacity = (int)(past_value.cstep / out_embed_dim); + if (dst_seqlen > key_capacity || dst_seqlen > value_capacity) + return -1; + + // In-place append cur_key/cur_value to preallocated buffer + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_group; q++) + { + unsigned char* pk = (unsigned char*)past_key.channel(q).data; + unsigned char* pv = (unsigned char*)past_value.channel(q).data; + memcpy(pk + (size_t)past_seqlen * embed_dim * elemsize, + cur_key.channel(q).data, embed_dim * cur_seqlen * elemsize); + memcpy(pv + (size_t)past_seqlen * out_embed_dim * elemsize, + cur_value.channel(q).data, out_embed_dim * cur_seqlen * elemsize); + } + } Mat key; - if (past_seqlen > 0) + Mat value; + if (kv_cache == 2 && past_key.dims > 0) + { + // Copy dst_seqlen rows from preallocated buffer (already appended above) + key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); + if (key.empty()) return -100; + value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); + if (value.empty()) return -100; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_group; q++) + { + memcpy(key.channel(q), past_key.channel(q), embed_dim * dst_seqlen * elemsize); + memcpy(value.channel(q), past_value.channel(q), out_embed_dim * dst_seqlen * elemsize); + } + } + else if (past_seqlen > 0) { key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); if (key.empty()) @@ -194,8 +256,7 @@ int SDPA_arm::forward(const std::vector& bottom_blobs, std::vector& to key = cur_key; } - Mat value; - if (past_seqlen > 0) + if (kv_cache != 2 && past_seqlen > 0) { value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); if (value.empty()) @@ -212,12 +273,11 @@ int SDPA_arm::forward(const std::vector& bottom_blobs, std::vector& to memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * elemsize); } } - else + else if (kv_cache != 2) { value = cur_value; } - const int num_heads_per_group = num_heads / num_group; #if NCNN_BF16 const bool use_bf16_storage = opt.use_bf16_storage && !opt.use_fp16_storage && query.elembits() == 16; @@ -361,7 +421,12 @@ int SDPA_arm::forward(const std::vector& bottom_blobs, std::vector& to value_fp32.release(); - if (kv_cache) + if (kv_cache == 2) + { + top_blobs[1] = make_persistent_kvcache_view_arm(past_key, dst_seqlen); + top_blobs[2] = make_persistent_kvcache_view_arm(past_value, dst_seqlen); + } + else if (kv_cache) { top_blobs[1] = key; top_blobs[2] = value; diff --git a/src/layer/loongarch/sdpa_loongarch.cpp b/src/layer/loongarch/sdpa_loongarch.cpp index 09afe94ae74d..0cca63b907bc 100644 --- a/src/layer/loongarch/sdpa_loongarch.cpp +++ b/src/layer/loongarch/sdpa_loongarch.cpp @@ -3,10 +3,18 @@ #include "sdpa_loongarch.h" +#include "cpu.h" #include "layer_type.h" namespace ncnn { +static Mat make_persistent_kvcache_view_loongarch(const Mat& cache, int seqlen) +{ + Mat view = cache; + view.h = seqlen; + return view; +} + SDPA_loongarch::SDPA_loongarch() { #if NCNN_BF16 @@ -143,8 +151,11 @@ int SDPA_loongarch::forward(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector 0) + past_seqlen = past_key.h; + + if (kv_cache == 2 && past_value.h != past_seqlen) + return -1; + const int dst_seqlen = past_seqlen + cur_seqlen; const size_t elemsize = query.elemsize; Mat key; - if (past_seqlen > 0) + Mat value; + if (kv_cache == 2 && past_key.dims > 0) + { + const int key_capacity = (int)(past_key.cstep / embed_dim); + const int value_capacity = (int)(past_value.cstep / out_embed_dim); + if (dst_seqlen > key_capacity || dst_seqlen > value_capacity) + return -1; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_group; q++) + { + unsigned char* pk = (unsigned char*)past_key.channel(q).data; + unsigned char* pv = (unsigned char*)past_value.channel(q).data; + memcpy(pk + (size_t)past_seqlen * embed_dim * elemsize, + cur_key.channel(q).data, embed_dim * cur_seqlen * elemsize); + memcpy(pv + (size_t)past_seqlen * out_embed_dim * elemsize, + cur_value.channel(q).data, out_embed_dim * cur_seqlen * elemsize); + } + + key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); + if (key.empty()) + return -100; + value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); + if (value.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_group; q++) + { + memcpy(key.channel(q), past_key.channel(q), embed_dim * dst_seqlen * elemsize); + memcpy(value.channel(q), past_value.channel(q), out_embed_dim * dst_seqlen * elemsize); + } + } + else if (past_seqlen > 0) { key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); if (key.empty()) @@ -180,8 +236,7 @@ int SDPA_loongarch::forward(const std::vector& bottom_blobs, std::vector 0) + if (kv_cache != 2 && past_seqlen > 0) { value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); if (value.empty()) @@ -198,7 +253,7 @@ int SDPA_loongarch::forward(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& t const Mat& cur_key = bottom_blobs[1]; const Mat& cur_value = bottom_blobs[2]; const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat(); - const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat(); - const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat(); + const int blob_offset = attn_mask ? 4 : 3; + if (kv_cache == 2 && (int)bottom_blobs.size() <= blob_offset + 1) + return -1; + const Mat& past_key = kv_cache ? bottom_blobs[blob_offset] : Mat(); + const Mat& past_value = kv_cache ? bottom_blobs[blob_offset + 1] : Mat(); const int embed_dim = query.w; const int src_seqlen = query.h; @@ -152,13 +163,58 @@ int SDPA_mips::forward(const std::vector& bottom_blobs, std::vector& t const int cur_seqlen = cur_key.h; const int num_group = cur_key.c; const int out_embed_dim = cur_value.w; - const int past_seqlen = kv_cache ? past_key.h : 0; + int past_seqlen = 0; + if (kv_cache == 2) + { + if (past_key.dims == 0 || past_value.dims == 0) + return -1; + past_seqlen = past_key.h; + } + else if (kv_cache == 1 && past_key.dims > 0) + past_seqlen = past_key.h; + + if (kv_cache == 2 && past_value.h != past_seqlen) + return -1; + const int dst_seqlen = past_seqlen + cur_seqlen; const size_t elemsize = query.elemsize; Mat key; - if (past_seqlen > 0) + Mat value; + if (kv_cache == 2 && past_key.dims > 0) + { + const int key_capacity = (int)(past_key.cstep / embed_dim); + const int value_capacity = (int)(past_value.cstep / out_embed_dim); + if (dst_seqlen > key_capacity || dst_seqlen > value_capacity) + return -1; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_group; q++) + { + unsigned char* pk = (unsigned char*)past_key.channel(q).data; + unsigned char* pv = (unsigned char*)past_value.channel(q).data; + memcpy(pk + (size_t)past_seqlen * embed_dim * elemsize, + cur_key.channel(q).data, embed_dim * cur_seqlen * elemsize); + memcpy(pv + (size_t)past_seqlen * out_embed_dim * elemsize, + cur_value.channel(q).data, out_embed_dim * cur_seqlen * elemsize); + } + + key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); + if (key.empty()) + return -100; + value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); + if (value.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_group; q++) + { + memcpy(key.channel(q), past_key.channel(q), embed_dim * dst_seqlen * elemsize); + memcpy(value.channel(q), past_value.channel(q), out_embed_dim * dst_seqlen * elemsize); + } + } + else if (past_seqlen > 0) { key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); if (key.empty()) @@ -180,8 +236,7 @@ int SDPA_mips::forward(const std::vector& bottom_blobs, std::vector& t key = cur_key; } - Mat value; - if (past_seqlen > 0) + if (kv_cache != 2 && past_seqlen > 0) { value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); if (value.empty()) @@ -198,7 +253,7 @@ int SDPA_mips::forward(const std::vector& bottom_blobs, std::vector& t memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * elemsize); } } - else + else if (kv_cache != 2) { value = cur_value; } @@ -337,7 +392,12 @@ int SDPA_mips::forward(const std::vector& bottom_blobs, std::vector& t value_fp32.release(); - if (kv_cache) + if (kv_cache == 2) + { + top_blobs[1] = make_persistent_kvcache_view_mips(past_key, dst_seqlen); + top_blobs[2] = make_persistent_kvcache_view_mips(past_value, dst_seqlen); + } + else if (kv_cache) { top_blobs[1] = key; top_blobs[2] = value; diff --git a/src/layer/sdpa.cpp b/src/layer/sdpa.cpp index e1824c4b76e4..ceb3b37c90de 100644 --- a/src/layer/sdpa.cpp +++ b/src/layer/sdpa.cpp @@ -9,6 +9,13 @@ namespace ncnn { +static Mat make_persistent_kvcache_view(const Mat& cache, int seqlen) +{ + Mat view = cache; + view.h = seqlen; + return view; +} + SDPA::SDPA() { } @@ -37,8 +44,11 @@ int SDPA::forward(const std::vector& bottom_blobs, std::vector& top_bl const Mat& cur_key = bottom_blobs[1]; const Mat& cur_value = bottom_blobs[2]; const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat(); - const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat(); - const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat(); + const int blob_offset = attn_mask ? 4 : 3; + if (kv_cache == 2 && (int)bottom_blobs.size() <= blob_offset + 1) + return -1; + const Mat& past_key = kv_cache ? bottom_blobs[blob_offset] : Mat(); + const Mat& past_value = kv_cache ? bottom_blobs[blob_offset + 1] : Mat(); const int embed_dim = query.w; const int src_seqlen = query.h; @@ -46,7 +56,20 @@ int SDPA::forward(const std::vector& bottom_blobs, std::vector& top_bl const int cur_seqlen = cur_key.h; const int num_group = cur_key.c; const int out_embed_dim = cur_value.w; - const int past_seqlen = kv_cache ? past_key.h : 0; + + int past_seqlen = 0; + if (kv_cache == 2) + { + if (past_key.dims == 0 || past_value.dims == 0) + return -1; + past_seqlen = past_key.h; + } + else if (kv_cache == 1 && past_key.dims > 0) + past_seqlen = past_key.h; + + if (kv_cache == 2 && past_value.h != past_seqlen) + return -1; + const int dst_seqlen = past_seqlen + cur_seqlen; // assert cur_key.w == embed_dim @@ -66,45 +89,94 @@ int SDPA::forward(const std::vector& bottom_blobs, std::vector& top_bl if (qk_cross.empty()) return -100; - Mat key = cur_key; - if (past_seqlen > 0) + Mat key; + Mat value; + if (kv_cache == 2 && past_key.dims > 0) { + const int key_capacity = (int)(past_key.cstep / embed_dim); + const int value_capacity = (int)(past_value.cstep / out_embed_dim); + if (dst_seqlen > key_capacity || dst_seqlen > value_capacity) + return -1; + + // In-place append: write cur into preallocated past buffer + const size_t elemsize = cur_key.elemsize; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_group; q++) + { + unsigned char* kd = (unsigned char*)past_key.channel(q).data + (size_t)past_seqlen * embed_dim * elemsize; + memcpy(kd, cur_key.channel(q).data, embed_dim * cur_seqlen * elemsize); + unsigned char* vd = (unsigned char*)past_value.channel(q).data + (size_t)past_seqlen * out_embed_dim * elemsize; + memcpy(vd, cur_value.channel(q).data, out_embed_dim * cur_seqlen * elemsize); + } + // Copy dst_seqlen rows into compact fp32 Mats for attention computation key.create(embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator); if (key.empty()) return -100; + value.create(out_embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator); + if (value.empty()) + return -100; - // concat #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < num_group; q++) { - const Mat past_key_head = past_key.channel(q); - const Mat cur_key_head = cur_key.channel(q); - Mat key_head = key.channel(q); + // Convert from source elemsize to fp32 + if (elemsize == 4) + { + memcpy(key.channel(q), past_key.channel(q), embed_dim * dst_seqlen * 4); + memcpy(value.channel(q), past_value.channel(q), out_embed_dim * dst_seqlen * 4); + } + else + { + // fp16/bf16 -> fp32 conversion + const unsigned short* ksrc = (const unsigned short*)past_key.channel(q).data; + float* kdst = (float*)key.channel(q); - memcpy(key_head.row(0), past_key_head, embed_dim * past_seqlen * sizeof(float)); - memcpy(key_head.row(past_seqlen), cur_key_head, embed_dim * cur_seqlen * sizeof(float)); + const unsigned short* vsrc = (const unsigned short*)past_value.channel(q).data; + float* vdst = (float*)value.channel(q); + if (opt.use_bf16_storage) + { + for (int i = 0; i < embed_dim * dst_seqlen; i++) + kdst[i] = ncnn::bfloat16_to_float32(ksrc[i]); + + for (int i = 0; i < out_embed_dim * dst_seqlen; i++) + vdst[i] = ncnn::bfloat16_to_float32(vsrc[i]); + } + else + { + for (int i = 0; i < embed_dim * dst_seqlen; i++) + kdst[i] = ncnn::float16_to_float32(ksrc[i]); + + for (int i = 0; i < out_embed_dim * dst_seqlen; i++) + vdst[i] = ncnn::float16_to_float32(vsrc[i]); + } + } } } - - Mat value = cur_value; - if (past_seqlen > 0) + else if (past_seqlen > 0) { + key.create(embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator); + if (key.empty()) + return -100; + value.create(out_embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator); if (value.empty()) return -100; - // concat + // concat key and value #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < num_group; q++) { - const Mat past_value_head = past_value.channel(q); - const Mat cur_value_head = cur_value.channel(q); - Mat value_head = value.channel(q); - - memcpy(value_head.row(0), past_value_head, out_embed_dim * past_seqlen * sizeof(float)); - memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * sizeof(float)); + memcpy((float*)key.channel(q), past_key.channel(q), embed_dim * past_seqlen * sizeof(float)); + memcpy((float*)key.channel(q) + embed_dim * past_seqlen, cur_key.channel(q), embed_dim * cur_seqlen * sizeof(float)); + memcpy((float*)value.channel(q), past_value.channel(q), out_embed_dim * past_seqlen * sizeof(float)); + memcpy((float*)value.channel(q) + out_embed_dim * past_seqlen, cur_value.channel(q), out_embed_dim * cur_seqlen * sizeof(float)); } } + else + { + key = cur_key; + value = cur_value; + } #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < num_heads; q++) @@ -198,7 +270,12 @@ int SDPA::forward(const std::vector& bottom_blobs, std::vector& top_bl } } - if (kv_cache) + if (kv_cache == 2) + { + top_blobs[1] = make_persistent_kvcache_view(past_key, dst_seqlen); + top_blobs[2] = make_persistent_kvcache_view(past_value, dst_seqlen); + } + else if (kv_cache) { // assert top_blobs.size() == 3 top_blobs[1] = key; @@ -278,8 +355,11 @@ int SDPA::forward_int8(const std::vector& bottom_blobs, std::vector& t const Mat& cur_key = bottom_blobs[1]; const Mat& cur_value = bottom_blobs[2]; const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat(); - const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat(); - const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat(); + const int blob_offset = attn_mask ? 4 : 3; + if (kv_cache == 2 && (int)bottom_blobs.size() <= blob_offset + 1) + return -1; + const Mat& past_key = kv_cache ? bottom_blobs[blob_offset] : Mat(); + const Mat& past_value = kv_cache ? bottom_blobs[blob_offset + 1] : Mat(); const int embed_dim = query.w; const int src_seqlen = query.h; @@ -287,7 +367,19 @@ int SDPA::forward_int8(const std::vector& bottom_blobs, std::vector& t const int cur_seqlen = cur_key.h; const int num_group = cur_key.c; const int out_embed_dim = cur_value.w; - const int past_seqlen = kv_cache ? past_key.h : 0; + int past_seqlen = 0; + if (kv_cache == 2) + { + if (past_key.dims == 0 || past_value.dims == 0) + return -1; + past_seqlen = past_key.h; + } + else if (kv_cache == 1 && past_key.dims > 0) + past_seqlen = past_key.h; + + if (kv_cache == 2 && past_value.h != past_seqlen) + return -1; + const int dst_seqlen = past_seqlen + cur_seqlen; // assert cur_key.w == embed_dim @@ -328,7 +420,52 @@ int SDPA::forward_int8(const std::vector& bottom_blobs, std::vector& t return -100; Mat key = cur_key; - if (past_seqlen > 0) + if (kv_cache == 2 && past_key.dims > 0) + { + const int key_capacity = (int)(past_key.cstep / embed_dim); + const int value_capacity = (int)(past_value.cstep / out_embed_dim); + if (dst_seqlen > key_capacity || dst_seqlen > value_capacity) + return -1; + + const size_t elemsize = cur_key.elemsize; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_group; q++) + { + unsigned char* kd = (unsigned char*)past_key.channel(q).data + (size_t)past_seqlen * embed_dim * elemsize; + memcpy(kd, cur_key.channel(q).data, embed_dim * cur_seqlen * elemsize); + unsigned char* vd = (unsigned char*)past_value.channel(q).data + (size_t)past_seqlen * out_embed_dim * elemsize; + memcpy(vd, cur_value.channel(q).data, out_embed_dim * cur_seqlen * elemsize); + } + + key.create(embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator); + if (key.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_group; q++) + { + if (elemsize == 4) + { + memcpy(key.channel(q), past_key.channel(q), embed_dim * dst_seqlen * sizeof(float)); + } + else + { + const unsigned short* ksrc = (const unsigned short*)past_key.channel(q).data; + float* kdst = (float*)key.channel(q); + if (opt.use_bf16_storage) + { + for (int i = 0; i < embed_dim * dst_seqlen; i++) + kdst[i] = ncnn::bfloat16_to_float32(ksrc[i]); + } + else + { + for (int i = 0; i < embed_dim * dst_seqlen; i++) + kdst[i] = ncnn::float16_to_float32(ksrc[i]); + } + } + } + } + else if (past_seqlen > 0) { key.create(embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator); if (key.empty()) @@ -348,7 +485,38 @@ int SDPA::forward_int8(const std::vector& bottom_blobs, std::vector& t } Mat value = cur_value; - if (past_seqlen > 0) + if (kv_cache == 2 && past_key.dims > 0) + { + const size_t elemsize = cur_value.elemsize; + value.create(out_embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator); + if (value.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_group; q++) + { + if (elemsize == 4) + { + memcpy(value.channel(q), past_value.channel(q), out_embed_dim * dst_seqlen * sizeof(float)); + } + else + { + const unsigned short* vsrc = (const unsigned short*)past_value.channel(q).data; + float* vdst = (float*)value.channel(q); + if (opt.use_bf16_storage) + { + for (int i = 0; i < out_embed_dim * dst_seqlen; i++) + vdst[i] = ncnn::bfloat16_to_float32(vsrc[i]); + } + else + { + for (int i = 0; i < out_embed_dim * dst_seqlen; i++) + vdst[i] = ncnn::float16_to_float32(vsrc[i]); + } + } + } + } + else if (past_seqlen > 0) { value.create(out_embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator); if (value.empty()) @@ -485,7 +653,12 @@ int SDPA::forward_int8(const std::vector& bottom_blobs, std::vector& t } } - if (kv_cache) + if (kv_cache == 2) + { + top_blobs[1] = make_persistent_kvcache_view(past_key, dst_seqlen); + top_blobs[2] = make_persistent_kvcache_view(past_value, dst_seqlen); + } + else if (kv_cache) { // assert top_blobs.size() == 3 top_blobs[1] = key; diff --git a/src/layer/vulkan/sdpa_vulkan.cpp b/src/layer/vulkan/sdpa_vulkan.cpp index 337d2de19c0b..5691d93bc317 100644 --- a/src/layer/vulkan/sdpa_vulkan.cpp +++ b/src/layer/vulkan/sdpa_vulkan.cpp @@ -8,6 +8,13 @@ namespace ncnn { +static VkMat make_persistent_kvcache_view(const VkMat& cache, int seqlen) +{ + VkMat view = cache; + view.h = seqlen; + return view; +} + SDPA_vulkan::SDPA_vulkan() { support_vulkan = true; @@ -17,6 +24,8 @@ SDPA_vulkan::SDPA_vulkan() qk_softmax = 0; kvcache_concat = 0; + pipeline_sdpa_kvcache_append = 0; + pipeline_sdpa_qk_cross = 0; pipeline_sdpa_qkv_cross = 0; @@ -58,6 +67,14 @@ int SDPA_vulkan::load_param(const ParamDict& pd) int SDPA_vulkan::create_pipeline(const Option& opt) { + { + std::vector specializations(0); + + pipeline_sdpa_kvcache_append = new Pipeline(vkdev); + pipeline_sdpa_kvcache_append->set_optimal_local_size_xyz(4, 4, 4); + pipeline_sdpa_kvcache_append->create(LayerShaderType::sdpa_kvcache_append, opt, specializations); + } + use_cooperative_matrix = vkdev->info.support_cooperative_matrix() && opt.use_cooperative_matrix && (opt.use_fp16_storage || opt.use_fp16_packed); bool use_bf16_cooperative_matrix = false; @@ -323,6 +340,9 @@ int SDPA_vulkan::create_pipeline(const Option& opt) int SDPA_vulkan::destroy_pipeline(const Option& opt) { + delete pipeline_sdpa_kvcache_append; + pipeline_sdpa_kvcache_append = 0; + delete pipeline_sdpa_qk_cross; pipeline_sdpa_qk_cross = 0; @@ -377,8 +397,11 @@ int SDPA_vulkan::forward(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector 0) + { + past_seqlen = past_key.h; + } + + if (kv_cache == 2 && past_value.h != past_seqlen) + return -1; + const int dst_seqlen = past_seqlen + cur_seqlen; const float _scale = scale == 0.f ? 1.f / sqrt(embed_dim) : scale; @@ -394,7 +431,52 @@ int SDPA_vulkan::forward(const std::vector& bottom_blobs, std::vector 0) + VkMat value; + if (kv_cache == 2) + { + if (past_key.w != embed_dim || past_value.w != out_embed_dim || past_key.c != num_group || past_value.c != num_group) + return -1; + + const int key_capacity = (int)(past_key.cstep / embed_dim); + const int value_capacity = (int)(past_value.cstep / out_embed_dim); + if (dst_seqlen > key_capacity || dst_seqlen > value_capacity) + return -1; + + { + std::vector bindings(2); + bindings[0] = cur_key; + bindings[1] = past_key; + + std::vector constants(6); + constants[0].i = cur_key.w; + constants[1].i = cur_key.h; + constants[2].i = cur_key.c; + constants[3].i = cur_key.cstep; + constants[4].i = past_key.cstep; + constants[5].i = past_seqlen; + + cmd.record_pipeline(pipeline_sdpa_kvcache_append, bindings, constants, cur_key); + } + { + std::vector bindings(2); + bindings[0] = cur_value; + bindings[1] = past_value; + + std::vector constants(6); + constants[0].i = cur_value.w; + constants[1].i = cur_value.h; + constants[2].i = cur_value.c; + constants[3].i = cur_value.cstep; + constants[4].i = past_value.cstep; + constants[5].i = past_seqlen; + + cmd.record_pipeline(pipeline_sdpa_kvcache_append, bindings, constants, cur_value); + } + + key = make_persistent_kvcache_view(past_key, dst_seqlen); + value = make_persistent_kvcache_view(past_value, dst_seqlen); + } + else if (past_seqlen > 0) { key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_vkallocator); if (key.empty()) @@ -416,8 +498,7 @@ int SDPA_vulkan::forward(const std::vector& bottom_blobs, std::vector 0) + if (kv_cache != 2 && past_seqlen > 0) { value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_vkallocator); if (value.empty()) @@ -432,7 +513,8 @@ int SDPA_vulkan::forward(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vectorforward_inplace(qk_cross, cmd, opt); - VkMat value; - if (past_seqlen > 0) + if (kv_cache != 2 && past_seqlen > 0) { value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_vkallocator); if (value.empty()) @@ -599,7 +680,8 @@ int SDPA_vulkan::forward(const std::vector& bottom_blobs, std::vector= p.w || gy >= p.h || gz >= p.c) + return; + + const int cur_offset = gz * p.cur_cstep + gy * p.w + gx; + const int cache_offset = gz * p.cache_cstep + (p.past_seqlen + gy) * p.w + gx; + + buffer_cp1(cache_blob_data, cache_offset, cur_blob_data, cur_offset); +} diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index a319913c3f9e..10a1358d33e9 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -3,10 +3,18 @@ #include "sdpa_x86.h" +#include "cpu.h" #include "layer_type.h" namespace ncnn { +static Mat make_persistent_kvcache_view_x86(const Mat& cache, int seqlen) +{ + Mat view = cache; + view.h = seqlen; + return view; +} + SDPA_x86::SDPA_x86() { #if NCNN_BF16 @@ -143,8 +151,11 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to const Mat& cur_key = bottom_blobs[1]; const Mat& cur_value = bottom_blobs[2]; const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat(); - const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat(); - const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat(); + const int blob_offset = attn_mask ? 4 : 3; + if (kv_cache == 2 && (int)bottom_blobs.size() <= blob_offset + 1) + return -1; + const Mat& past_key = kv_cache ? bottom_blobs[blob_offset] : Mat(); + const Mat& past_value = kv_cache ? bottom_blobs[blob_offset + 1] : Mat(); const int embed_dim = query.w; const int src_seqlen = query.h; @@ -152,13 +163,58 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to const int cur_seqlen = cur_key.h; const int num_group = cur_key.c; const int out_embed_dim = cur_value.w; - const int past_seqlen = kv_cache ? past_key.h : 0; + int past_seqlen = 0; + if (kv_cache == 2) + { + if (past_key.dims == 0 || past_value.dims == 0) + return -1; + past_seqlen = past_key.h; + } + else if (kv_cache == 1 && past_key.dims > 0) + past_seqlen = past_key.h; + + if (kv_cache == 2 && past_value.h != past_seqlen) + return -1; + const int dst_seqlen = past_seqlen + cur_seqlen; const size_t elemsize = query.elemsize; Mat key; - if (past_seqlen > 0) + Mat value; + if (kv_cache == 2 && past_key.dims > 0) + { + const int key_capacity = (int)(past_key.cstep / embed_dim); + const int value_capacity = (int)(past_value.cstep / out_embed_dim); + if (dst_seqlen > key_capacity || dst_seqlen > value_capacity) + return -1; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_group; q++) + { + unsigned char* pk = (unsigned char*)past_key.channel(q).data; + unsigned char* pv = (unsigned char*)past_value.channel(q).data; + memcpy(pk + (size_t)past_seqlen * embed_dim * elemsize, + cur_key.channel(q).data, embed_dim * cur_seqlen * elemsize); + memcpy(pv + (size_t)past_seqlen * out_embed_dim * elemsize, + cur_value.channel(q).data, out_embed_dim * cur_seqlen * elemsize); + } + + key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); + if (key.empty()) + return -100; + value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); + if (value.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_group; q++) + { + memcpy(key.channel(q), past_key.channel(q), embed_dim * dst_seqlen * elemsize); + memcpy(value.channel(q), past_value.channel(q), out_embed_dim * dst_seqlen * elemsize); + } + } + else if (past_seqlen > 0) { key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); if (key.empty()) @@ -180,8 +236,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to key = cur_key; } - Mat value; - if (past_seqlen > 0) + if (kv_cache != 2 && past_seqlen > 0) { value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); if (value.empty()) @@ -198,7 +253,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * elemsize); } } - else + else if (kv_cache != 2) { value = cur_value; } @@ -337,7 +392,12 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to value_fp32.release(); - if (kv_cache) + if (kv_cache == 2) + { + top_blobs[1] = make_persistent_kvcache_view_x86(past_key, dst_seqlen); + top_blobs[2] = make_persistent_kvcache_view_x86(past_value, dst_seqlen); + } + else if (kv_cache) { top_blobs[1] = key; top_blobs[2] = value; diff --git a/tests/perf/CMakeLists.txt b/tests/perf/CMakeLists.txt index 10c0535d8087..515736a6b853 100644 --- a/tests/perf/CMakeLists.txt +++ b/tests/perf/CMakeLists.txt @@ -43,4 +43,8 @@ ncnn_add_layer_perf(BatchNorm) if(WITH_LAYER_sdpa) ncnn_add_perf(sdpa_decode) ncnn_add_perf(sdpa_prefill) + ncnn_add_perf(sdpa_flash) + ncnn_add_perf(sdpa_mem) + ncnn_add_perf(sdpa_mla_kvcache) + ncnn_add_perf(sdpa_mla_sweep) endif() diff --git a/tests/perf/perf_sdpa_flash.cpp b/tests/perf/perf_sdpa_flash.cpp new file mode 100644 index 000000000000..eb231d5e8040 --- /dev/null +++ b/tests/perf/perf_sdpa_flash.cpp @@ -0,0 +1,157 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +// Micro-benchmark: flash attention (kv_cache=2) vs Gemm baseline (kv_cache=0) +// at the exact dimensions used by Youtu-LLM-2B (MLA, d_k=192, d_v=128). +// +// Both paths receive identical input shapes - the only difference is which +// internal code path SDPA_arm takes. Run side-by-side to see the algorithmic +// improvement of FlashAttention-2 vs ncnn's Gemm-based attention. + +#include "perfutil.h" + +static ncnn::Mat PerfCausalMask(int src_seqlen, int past_seqlen) +{ + const int total = past_seqlen + src_seqlen; + ncnn::Mat mask(total, src_seqlen); + mask.fill(0.f); + for (int i = 0; i < src_seqlen; i++) + { + float* row = mask.row(i); + for (int j = past_seqlen + i + 1; j < total; j++) + row[j] = -1e38f; + } + return mask; +} + +// Shared dimensions: matching Youtu-LLM-2B +// d_k = 192 (MLA query/key head dim) +// d_v = 128 (MLA value head dim) +// heads = 128 +// groups = 16 (GQA 8:1) + +// kv_cache=0: pure prefill via ncnn Gemm path +static void perf_gemm_prefill(int d_k, int d_v, int heads, int groups, int src_seqlen) +{ + if (!perf_match_env_int("NCNN_PERF_SDPA_M", src_seqlen)) + return; + + const bool causal = perf_env_int("NCNN_PERF_SDPA_CAUSAL", 0, 0) != 0; + + ncnn::ParamDict pd; + pd.set(5, causal ? 1 : 0); // attn_mask + pd.set(6, 0.f); // scale = 0 (default 1/sqrt(d_k)) + pd.set(7, 0); // kv_cache = 0 + + std::vector weights(0); + std::vector inputs; + inputs.resize(causal ? 4 : 3); + inputs[0] = PerfMat(d_k, src_seqlen, heads); // q + inputs[1] = PerfMat(d_k, src_seqlen, groups); // k + inputs[2] = PerfMat(d_v, src_seqlen, groups); // v + if (causal) + inputs[3] = PerfCausalMask(src_seqlen, 0); + + perf_layer("SDPA", pd, weights, inputs, 1, + "GEMM d_k=%d d_v=%d h=%d g=%d M=%d causal=%d", + d_k, d_v, heads, groups, src_seqlen, causal ? 1 : 0); +} + +// kv_cache=2 prefill: in-place append (past_seqlen=0) + flash prefill +static void perf_flash_prefill(int d_k, int d_v, int heads, int groups, int src_seqlen, int n_ctx) +{ + if (!perf_match_env_int("NCNN_PERF_SDPA_M", src_seqlen)) + return; + + const bool causal = perf_env_int("NCNN_PERF_SDPA_CAUSAL", 0, 0) != 0; + + ncnn::ParamDict pd; + pd.set(5, causal ? 1 : 0); // attn_mask + pd.set(6, 0.f); // scale + pd.set(7, 2); // kv_cache = 2 (in-place append + flash) + + std::vector weights(0); + std::vector inputs; + inputs.resize(causal ? 6 : 5); + inputs[0] = PerfMat(d_k, src_seqlen, heads); // q + inputs[1] = PerfMat(d_k, src_seqlen, groups); // cur_k + inputs[2] = PerfMat(d_v, src_seqlen, groups); // cur_v + int offset = 3; + if (causal) + inputs[offset++] = PerfCausalMask(src_seqlen, 0); + ncnn::Mat past_key = PerfMat(d_k, n_ctx, groups); + ncnn::Mat past_value = PerfMat(d_v, n_ctx, groups); + past_key.h = 0; + past_value.h = 0; + inputs[offset++] = past_key; // past_k view (capacity in cstep) + inputs[offset++] = past_value; // past_v view (capacity in cstep) + + perf_layer("SDPA", pd, weights, inputs, 3, + "FLASH d_k=%d d_v=%d h=%d g=%d M=%d ctx=%d causal=%d", + d_k, d_v, heads, groups, src_seqlen, n_ctx, causal ? 1 : 0); +} + +// kv_cache=2 decode: in-place append (past_seqlen>0) + flash decode +static void perf_flash_decode(int d_k, int d_v, int heads, int groups, int past_seqlen, int n_ctx) +{ + if (!perf_match_env_int("NCNN_PERF_SDPA_PAST", past_seqlen)) + return; + + ncnn::ParamDict pd; + pd.set(5, 0); + pd.set(6, 0.f); + pd.set(7, 2); + + std::vector weights(0); + std::vector inputs(5); + inputs[0] = PerfMat(d_k, 1, heads); // q (1 token) + inputs[1] = PerfMat(d_k, 1, groups); // cur_k (1 token) + inputs[2] = PerfMat(d_v, 1, groups); // cur_v (1 token) + inputs[3] = PerfMat(d_k, n_ctx, groups); // past_k capacity + inputs[4] = PerfMat(d_v, n_ctx, groups); // past_v capacity + inputs[3].h = past_seqlen; + inputs[4].h = past_seqlen; + + perf_layer("SDPA", pd, weights, inputs, 3, + "FLASH d_k=%d d_v=%d h=%d g=%d past=%d ctx=%d", + d_k, d_v, heads, groups, past_seqlen, n_ctx); +} + +int main() +{ + const int d_k = 192; + const int d_v = 128; + const int heads = 128; + const int groups = 16; + const int n_ctx = 4096; + + const bool run_prefill = !perf_has_env("NCNN_PERF_SDPA_DECODE_ONLY"); + const bool run_decode = !perf_has_env("NCNN_PERF_SDPA_PREFILL_ONLY"); + + if (run_prefill) + { + fprintf(stdout, "=== Prefill: Gemm vs Flash (Youtu-LLM-2B dims) ===\n\n"); + + int seqlens[] = {32, 64, 128, 256, 512, 1024}; + for (int i = 0; i < (int)(sizeof(seqlens) / sizeof(seqlens[0])); i++) + { + int M = seqlens[i]; + perf_gemm_prefill(d_k, d_v, heads, groups, M); + perf_flash_prefill(d_k, d_v, heads, groups, M, n_ctx); + fprintf(stdout, "\n"); + } + } + + if (run_decode) + { + fprintf(stdout, "=== Decode: Flash with varying past_seqlen ===\n\n"); + + int pasts[] = {32, 128, 256, 512, 1024, 2048}; + for (int i = 0; i < (int)(sizeof(pasts) / sizeof(pasts[0])); i++) + { + perf_flash_decode(d_k, d_v, heads, groups, pasts[i], n_ctx); + } + } + + return 0; +} diff --git a/tests/perf/perf_sdpa_mem.cpp b/tests/perf/perf_sdpa_mem.cpp new file mode 100644 index 000000000000..a43be265a799 --- /dev/null +++ b/tests/perf/perf_sdpa_mem.cpp @@ -0,0 +1,273 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause +// +// Memory comparison: SDPA Gemm path vs Flash path at MLA dimensions. +// +// Each path is run once with a tracking allocator that records: +// - peak workspace bytes (qk_cross etc) +// - peak blob bytes (top_blob, key/value copies) +// - total bytes allocated cumulative over the call +// +// Flash's main theoretical advantage is avoiding the O(M*N*heads) qk_cross +// matrix. This tool quantifies that for our model dims. + +#include "allocator.h" +#include "benchmark.h" +#include "layer.h" +#include "mat.h" +#include "modelbin.h" +#include "option.h" +#include "paramdict.h" + +#include +#include +#include + +// ----------------------------------------------------------------------------- +// Tracking allocator: counts bytes through fastMalloc/fastFree. +// Reports peak (max live) and total (cumulative allocated) bytes. +// ----------------------------------------------------------------------------- +class TrackingAllocator : public ncnn::Allocator +{ +public: + TrackingAllocator() : peak_bytes(0), live_bytes(0), total_bytes(0), num_allocs(0) {} + + virtual void* fastMalloc(size_t size) + { + // Allocate with header to remember size on free + size_t total = size + 16; + void* raw = ncnn::fastMalloc(total); + ((size_t*)raw)[0] = size; + live_bytes += size; + total_bytes += size; + num_allocs++; + if (live_bytes > peak_bytes) peak_bytes = live_bytes; + return (void*)((char*)raw + 16); + } + + virtual void fastFree(void* ptr) + { + if (!ptr) return; + void* raw = (void*)((char*)ptr - 16); + size_t size = ((size_t*)raw)[0]; + live_bytes -= size; + ncnn::fastFree(raw); + } + + void reset() + { + peak_bytes = 0; + live_bytes = 0; + total_bytes = 0; + num_allocs = 0; + } + + size_t peak_bytes; + size_t live_bytes; + size_t total_bytes; + int num_allocs; +}; + +// ----------------------------------------------------------------------------- +// Helpers +// ----------------------------------------------------------------------------- +static ncnn::Mat make_mat(int w, int h, int c, float v = 0.01f) +{ + ncnn::Mat m(w, h, c); + m.fill(v); + return m; +} + +// ----------------------------------------------------------------------------- +// Run SDPA forward with the given inputs, return tracked memory +// ----------------------------------------------------------------------------- +struct MemReport +{ + size_t ws_peak; + size_t ws_total; + int ws_allocs; + size_t blob_peak; + size_t blob_total; + int blob_allocs; + double time_ms; +}; + +static MemReport run_sdpa(int kv_cache_mode, + int d_k, int d_v, int heads, int groups, + int M, int N, int n_ctx, + bool use_fp16) +{ + MemReport r = {0, 0, 0, 0, 0, 0, 0.0}; + + ncnn::ParamDict pd; + pd.set(5, 0); // attn_mask + pd.set(6, 0.f); // scale (auto) + pd.set(7, kv_cache_mode); + + ncnn::Layer* op = ncnn::create_layer_cpu("SDPA"); + op->load_param(pd); + std::vector weights(0); + ncnn::ModelBinFromMatArray mb(weights.data()); + op->load_model(mb); + + TrackingAllocator ws_alloc; + TrackingAllocator blob_alloc; + + ncnn::Option opt; + opt.lightmode = true; + opt.num_threads = 1; + opt.use_packing_layout = false; + opt.use_fp16_storage = use_fp16; + opt.use_fp16_arithmetic = use_fp16; + opt.use_fp16_packed = use_fp16; + opt.use_bf16_storage = false; + opt.workspace_allocator = &ws_alloc; + opt.blob_allocator = &blob_alloc; + + op->create_pipeline(opt); + + // Build inputs (as fp32; they'll be cast inside SDPA forward) + std::vector inputs; + int n_outputs = 1; + + if (kv_cache_mode == 0) + { + // No kv_cache: just q, k, v + inputs.resize(3); + inputs[0] = make_mat(d_k, M, heads); + inputs[1] = make_mat(d_k, N, groups); + inputs[2] = make_mat(d_v, N, groups); + } + else if (kv_cache_mode == 2) + { + // kv_cache=2: q, cur_k, cur_v, past_k(view over preallocated cache), past_v(view over preallocated cache) + inputs.resize(5); + inputs[0] = make_mat(d_k, M, heads); + inputs[1] = make_mat(d_k, N, groups); // cur_k = N rows (past=0 prefill) + inputs[2] = make_mat(d_v, N, groups); // cur_v + inputs[3] = make_mat(d_k, n_ctx, groups); + inputs[4] = make_mat(d_v, n_ctx, groups); + n_outputs = 3; + } + + // Need to convert inputs for fp16 path + if (use_fp16) + { + for (size_t i = 0; i < inputs.size(); i++) + { + ncnn::Mat casted; + ncnn::cast_float32_to_float16(inputs[i], casted, opt); + inputs[i] = casted; + } + } + + if (kv_cache_mode == 2) + { + inputs[3].h = 0; + inputs[4].h = 0; + } + + // Reset trackers (input creation went through them) + ws_alloc.reset(); + blob_alloc.reset(); + + // Warmup once (caches workspace pool) + { + std::vector outs(n_outputs); + op->forward(inputs, outs, opt); + } + + // Reset, then run a single timed call + ws_alloc.reset(); + blob_alloc.reset(); + + double t0 = ncnn::get_current_time(); + { + std::vector outs(n_outputs); + op->forward(inputs, outs, opt); + } + double t1 = ncnn::get_current_time(); + + r.ws_peak = ws_alloc.peak_bytes; + r.ws_total = ws_alloc.total_bytes; + r.ws_allocs = ws_alloc.num_allocs; + r.blob_peak = blob_alloc.peak_bytes; + r.blob_total = blob_alloc.total_bytes; + r.blob_allocs = blob_alloc.num_allocs; + r.time_ms = t1 - t0; + + op->destroy_pipeline(opt); + delete op; + return r; +} + +// ----------------------------------------------------------------------------- +// Pretty printing +// ----------------------------------------------------------------------------- +static const char* fmt_bytes(size_t b, char* buf) +{ + if (b >= (size_t)1024 * 1024 * 1024) + snprintf(buf, 32, "%6.2f GB", b / (1024.0 * 1024.0 * 1024.0)); + else if (b >= (size_t)1024 * 1024) + snprintf(buf, 32, "%6.2f MB", b / (1024.0 * 1024.0)); + else if (b >= 1024) + snprintf(buf, 32, "%6.2f KB", b / 1024.0); + else + snprintf(buf, 32, "%5zu B", b); + return buf; +} + +int main() +{ + const int d_k = 192; + const int d_v = 128; + const int heads = 128; + const int groups = 16; + const int n_ctx = 4096; + + int seqlens[] = {32, 64, 128, 256, 512, 1024, 2048}; + + fprintf(stdout, + "Per-call SDPA memory comparison (Youtu-LLM-2B dims: d_k=192 d_v=128 h=128 g=16, fp16)\n" + " GEMM = ncnn Gemm path (kv_cache=0)\n" + " FLASH = our flash kernel (kv_cache=2, past=0, n_ctx=%d)\n" + " workspace = qk_cross / K_packed / Q_packed (intermediate)\n" + " blob = top_blob (output) + intermediate Mats (key/value copies)\n" + " total = cumulative bytes allocated (incl. reused pool slots)\n\n", + n_ctx); + + fprintf(stdout, + "%-6s | %-10s | %-10s %-10s %4s | %-10s %-10s %4s | %s\n", + "M=N", "PATH", "WS peak", "WS total", "WS#", + "Blob peak", "Blob total", "B#", + "time(ms)"); + fprintf(stdout, "%.110s\n", "------------------------------------------------------------------------------------------------------------"); + + for (size_t i = 0; i < sizeof(seqlens) / sizeof(seqlens[0]); i++) + { + int s = seqlens[i]; + MemReport gemm = run_sdpa(0, d_k, d_v, heads, groups, s, s, n_ctx, true); + MemReport flash = run_sdpa(2, d_k, d_v, heads, groups, s, s, n_ctx, true); + + char b1[32], b2[32], b3[32], b4[32]; + fprintf(stdout, "%-6d | %-10s | %-10s %-10s %4d | %-10s %-10s %4d | %7.2f\n", + s, "GEMM", + fmt_bytes(gemm.ws_peak, b1), fmt_bytes(gemm.ws_total, b2), gemm.ws_allocs, + fmt_bytes(gemm.blob_peak, b3), fmt_bytes(gemm.blob_total, b4), gemm.blob_allocs, + gemm.time_ms); + fprintf(stdout, "%-6s | %-10s | %-10s %-10s %4d | %-10s %-10s %4d | %7.2f\n", + "", "FLASH", + fmt_bytes(flash.ws_peak, b1), fmt_bytes(flash.ws_total, b2), flash.ws_allocs, + fmt_bytes(flash.blob_peak, b3), fmt_bytes(flash.blob_total, b4), flash.blob_allocs, + flash.time_ms); + + // Ratio summary + double ws_ratio = gemm.ws_peak > 0 ? (double)gemm.ws_peak / flash.ws_peak : 0; + double blob_ratio = gemm.blob_peak > 0 ? (double)gemm.blob_peak / (flash.blob_peak ? flash.blob_peak : 1) : 0; + fprintf(stdout, " | SAVING | WS %.1fx %s | Blob %.1fx |\n\n", + ws_ratio, ws_ratio > 1 ? "smaller" : "larger ", + blob_ratio); + } + + return 0; +} diff --git a/tests/perf/perf_sdpa_mla_kvcache.cpp b/tests/perf/perf_sdpa_mla_kvcache.cpp new file mode 100644 index 000000000000..085a32bbb7ef --- /dev/null +++ b/tests/perf/perf_sdpa_mla_kvcache.cpp @@ -0,0 +1,70 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "perfutil.h" + +static bool should_run_case(int kv_cache, int past_seqlen, int threads) +{ + return perf_match_env_int("NCNN_PERF_SDPA_KVCACHE", kv_cache) + && perf_match_env_int("NCNN_PERF_SDPA_PAST", past_seqlen) + && perf_match_env_int("NCNN_PERF_THREADS", threads); +} + +static void perf_sdpa_mla_decode(int kv_cache, int past_seqlen) +{ + const int d_k = perf_env_int("NCNN_PERF_SDPA_DK", 192, 1); + const int d_v = perf_env_int("NCNN_PERF_SDPA_DV", 128, 1); + const int num_heads = perf_env_int("NCNN_PERF_SDPA_HEADS", 128, 1); + const int num_groups = perf_env_int("NCNN_PERF_SDPA_GROUPS", 16, 1); + const int n_ctx_default = past_seqlen + 1 > 4096 ? past_seqlen + 1 : 4096; + const int n_ctx = perf_env_int("NCNN_PERF_SDPA_CTX", n_ctx_default, past_seqlen + 1); + const int threads = perf_env_int("NCNN_PERF_THREADS", 1, 1); + + if (!should_run_case(kv_cache, past_seqlen, threads)) + return; + + ncnn::ParamDict pd; + pd.set(5, 0); // attn_mask = 0 + pd.set(6, 0.f); // scale = 0 (default 1/sqrt(d_k)) + pd.set(7, kv_cache); + + std::vector weights(0); + std::vector inputs; + + if (kv_cache == 1) + { + inputs.resize(5); + inputs[0] = PerfMat(d_k, 1, num_heads); // q + inputs[1] = PerfMat(d_k, 1, num_groups); // cur_k + inputs[2] = PerfMat(d_v, 1, num_groups); // cur_v + inputs[3] = PerfMat(d_k, past_seqlen, num_groups); // past_k + inputs[4] = PerfMat(d_v, past_seqlen, num_groups); // past_v + } + else + { + inputs.resize(5); + inputs[0] = PerfMat(d_k, 1, num_heads); // q + inputs[1] = PerfMat(d_k, 1, num_groups); // cur_k + inputs[2] = PerfMat(d_v, 1, num_groups); // cur_v + inputs[3] = PerfMat(d_k, n_ctx, num_groups); // preallocated past_k + inputs[4] = PerfMat(d_v, n_ctx, num_groups); // preallocated past_v + inputs[3].h = past_seqlen; + inputs[4].h = past_seqlen; + } + + perf_layer("SDPA", pd, weights, inputs, 3, + "MLA kv_cache=%d d_k=%d d_v=%d h=%d g=%d past=%d ctx=%d t=%d", + kv_cache, d_k, d_v, num_heads, num_groups, past_seqlen, n_ctx, threads); +} + +int main() +{ + int pasts[] = {0, 128, 512, 1024, 2048, 4096}; + for (int i = 0; i < (int)(sizeof(pasts) / sizeof(pasts[0])); i++) + { + perf_sdpa_mla_decode(1, pasts[i]); + perf_sdpa_mla_decode(2, pasts[i]); + } + + return 0; +} diff --git a/tests/perf/perf_sdpa_mla_sweep.cpp b/tests/perf/perf_sdpa_mla_sweep.cpp new file mode 100644 index 000000000000..045723f79668 --- /dev/null +++ b/tests/perf/perf_sdpa_mla_sweep.cpp @@ -0,0 +1,376 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "allocator.h" +#include "benchmark.h" +#include "cpu.h" +#include "layer.h" +#include "mat.h" +#include "modelbin.h" +#include "option.h" +#include "paramdict.h" + +#include +#include +#include +#include +#include + +class TrackingAllocator : public ncnn::Allocator +{ +public: + TrackingAllocator() : peak_bytes(0), live_bytes(0), baseline_live_bytes(0), total_bytes(0), num_allocs(0) {} + + virtual void* fastMalloc(size_t size) + { + size_t total = size + 16; + void* raw = ncnn::fastMalloc(total); + ((size_t*)raw)[0] = size; + live_bytes += size; + total_bytes += size; + num_allocs++; + size_t relative_live_bytes = live_bytes > baseline_live_bytes ? live_bytes - baseline_live_bytes : 0; + if (relative_live_bytes > peak_bytes) + peak_bytes = relative_live_bytes; + return (void*)((char*)raw + 16); + } + + virtual void fastFree(void* ptr) + { + if (!ptr) + return; + + void* raw = (void*)((char*)ptr - 16); + size_t size = ((size_t*)raw)[0]; + if (size <= live_bytes) + live_bytes -= size; + else + live_bytes = 0; + ncnn::fastFree(raw); + } + + void reset() + { + peak_bytes = 0; + baseline_live_bytes = live_bytes; + total_bytes = 0; + num_allocs = 0; + } + + size_t peak_bytes; + size_t live_bytes; + size_t baseline_live_bytes; + size_t total_bytes; + int num_allocs; +}; + +struct Report +{ + double median_ms; + size_t ws_peak; + size_t blob_peak; + size_t op_peak; + size_t ws_total; + size_t blob_total; + int ws_allocs; + int blob_allocs; +}; + +static int env_int(const char* name, int default_value) +{ + const char* s = getenv(name); + if (!s || !s[0]) + return default_value; + return atoi(s); +} + +static bool env_match_int(const char* name, int value) +{ + const char* s = getenv(name); + if (!s || !s[0]) + return true; + return atoi(s) == value; +} + +static bool env_match_str(const char* name, const char* value) +{ + const char* s = getenv(name); + if (!s || !s[0]) + return true; + return strcmp(s, value) == 0; +} + +static ncnn::Mat make_mat(int w, int h, int c, float v = 0.01f) +{ + ncnn::Mat m(w, h, c); + m.fill(v); + return m; +} + +static ncnn::Mat make_causal_mask(int src_seqlen, int past_seqlen) +{ + const int total = src_seqlen + past_seqlen; + ncnn::Mat mask(total, src_seqlen); + mask.fill(0.f); + + for (int i = 0; i < src_seqlen; i++) + { + float* row = mask.row(i); + for (int j = past_seqlen + i + 1; j < total; j++) + row[j] = -1e38f; + } + + return mask; +} + +static void convert_input_layout(const ncnn::Mat& src, ncnn::Mat& dst, const ncnn::Option& opt, const ncnn::Layer* op) +{ + ncnn::Mat casted; + +#if NCNN_ARM82 + if (opt.use_fp16_storage && ncnn::cpu_support_arm_asimdhp() && op->support_fp16_storage) + { + ncnn::cast_float32_to_float16(src, casted, opt); + } + else +#endif + if (opt.use_fp16_storage && op->support_fp16_storage) + { + ncnn::cast_float32_to_float16(src, casted, opt); + } + else + { + casted = src; + } + + if (opt.use_packing_layout && op->support_packing) + { + int elemcount = 0; + if (casted.dims == 1) elemcount = casted.elempack * casted.w; + if (casted.dims == 2) elemcount = casted.elempack * casted.h; + if (casted.dims == 3 || casted.dims == 4) elemcount = casted.elempack * casted.c; + + int dst_elempack = 1; + if (casted.elembits() == 16) + { +#if NCNN_ARM82 + if (elemcount % 8 == 0 && ncnn::cpu_support_arm_asimdhp() && opt.use_fp16_arithmetic && op->support_fp16_storage) + dst_elempack = 8; + else if (elemcount % 4 == 0) + dst_elempack = 4; +#else + if (elemcount % 4 == 0) + dst_elempack = 4; +#endif + } + else if (casted.elembits() == 32) + { + if (elemcount % 4 == 0) + dst_elempack = 4; + } + else if (casted.elembits() == 8) + { + if (elemcount % 8 == 0) + dst_elempack = 8; + } + + ncnn::convert_packing(casted, dst, dst_elempack, opt); + } + else + { + dst = casted; + } +} + +static void convert_input_layout_persistent_view(const ncnn::Mat& src, ncnn::Mat& dst, const ncnn::Option& opt, const ncnn::Layer* op) +{ + ncnn::Mat src_full = src; + const int capacity = src.w == 0 ? src.h : (int)(src.cstep / src.w); + src_full.h = capacity; + + convert_input_layout(src_full, dst, opt, op); + dst.h = src.h; +} + +static Report run_sdpa(const char* phase, int kv_cache, int M, int past, int n_ctx, int threads) +{ + const int d_k = 192; + const int d_v = 128; + const int heads = 128; + const int groups = 16; + const int warmup_count = env_int("NCNN_MLA_SWEEP_WARMUP", 3); + const int run_count = env_int("NCNN_MLA_SWEEP_RUNS", 9); + + ncnn::ParamDict pd; + pd.set(5, strcmp(phase, "prefill") == 0 ? 1 : 0); + pd.set(6, 0.f); + pd.set(7, kv_cache); + + ncnn::Layer* op = ncnn::create_layer_cpu("SDPA"); + op->load_param(pd); + std::vector weights(0); + ncnn::ModelBinFromMatArray mb(weights.data()); + op->load_model(mb); + + TrackingAllocator ws_alloc; + TrackingAllocator blob_alloc; + + ncnn::Option opt; + opt.lightmode = true; + opt.num_threads = threads; + opt.use_packing_layout = true; + opt.use_fp16_storage = true; + opt.use_fp16_arithmetic = true; + opt.use_fp16_packed = true; + opt.use_bf16_storage = false; + opt.workspace_allocator = &ws_alloc; + opt.blob_allocator = &blob_alloc; + + op->create_pipeline(opt); + + std::vector inputs; + int output_count = 1; + + if (strcmp(phase, "prefill") == 0 && kv_cache == 0) + { + inputs.resize(4); + inputs[0] = make_mat(d_k, M, heads); + inputs[1] = make_mat(d_k, M, groups); + inputs[2] = make_mat(d_v, M, groups); + inputs[3] = make_causal_mask(M, 0); + } + else if (strcmp(phase, "prefill") == 0 && kv_cache == 2) + { + inputs.resize(6); + inputs[0] = make_mat(d_k, M, heads); + inputs[1] = make_mat(d_k, M, groups); + inputs[2] = make_mat(d_v, M, groups); + inputs[3] = make_causal_mask(M, 0); + inputs[4] = make_mat(d_k, n_ctx, groups); + inputs[5] = make_mat(d_v, n_ctx, groups); + inputs[4].h = 0; + inputs[5].h = 0; + output_count = 3; + } + else if (strcmp(phase, "decode") == 0 && kv_cache == 1) + { + inputs.resize(5); + inputs[0] = make_mat(d_k, 1, heads); + inputs[1] = make_mat(d_k, 1, groups); + inputs[2] = make_mat(d_v, 1, groups); + inputs[3] = make_mat(d_k, past, groups); + inputs[4] = make_mat(d_v, past, groups); + output_count = 3; + } + else if (strcmp(phase, "decode") == 0 && kv_cache == 2) + { + inputs.resize(5); + inputs[0] = make_mat(d_k, 1, heads); + inputs[1] = make_mat(d_k, 1, groups); + inputs[2] = make_mat(d_v, 1, groups); + inputs[3] = make_mat(d_k, n_ctx, groups); + inputs[4] = make_mat(d_v, n_ctx, groups); + inputs[3].h = past; + inputs[4].h = past; + output_count = 3; + } + + std::vector converted_inputs(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) + { + const int cache_offset = strcmp(phase, "prefill") == 0 ? 4 : 3; + if (kv_cache == 2 && (i == (size_t)cache_offset || i == (size_t)(cache_offset + 1))) + convert_input_layout_persistent_view(inputs[i], converted_inputs[i], opt, op); + else + convert_input_layout(inputs[i], converted_inputs[i], opt, op); + } + + for (int i = 0; i < warmup_count; i++) + { + std::vector outputs(output_count); + op->forward(converted_inputs, outputs, opt); + } + + std::vector times; + times.reserve(run_count); + Report r = {0.0, 0, 0, 0, 0, 0, 0, 0}; + + for (int i = 0; i < run_count; i++) + { + ws_alloc.reset(); + blob_alloc.reset(); + + double t0 = ncnn::get_current_time(); + { + std::vector outputs(output_count); + op->forward(converted_inputs, outputs, opt); + } + double t1 = ncnn::get_current_time(); + + times.push_back(t1 - t0); + if (ws_alloc.peak_bytes + blob_alloc.peak_bytes > r.op_peak) + { + r.ws_peak = ws_alloc.peak_bytes; + r.blob_peak = blob_alloc.peak_bytes; + r.op_peak = ws_alloc.peak_bytes + blob_alloc.peak_bytes; + r.ws_total = ws_alloc.total_bytes; + r.blob_total = blob_alloc.total_bytes; + r.ws_allocs = ws_alloc.num_allocs; + r.blob_allocs = blob_alloc.num_allocs; + } + } + + std::sort(times.begin(), times.end()); + r.median_ms = times[times.size() / 2]; + + op->destroy_pipeline(opt); + delete op; + + return r; +} + +static void print_report(const char* mode, const char* phase, int kv_cache, int len, int n_ctx, int threads, const Report& r) +{ + fprintf(stdout, + "%-8s %-7s kv=%d len=%-5d ctx=%-5d t=%d median_ms=%9.4f ws_peak=%10zu blob_peak=%10zu op_peak=%10zu ws_total=%10zu blob_total=%10zu ws_allocs=%3d blob_allocs=%3d\n", + mode, phase, kv_cache, len, n_ctx, threads, r.median_ms, + r.ws_peak, r.blob_peak, r.op_peak, r.ws_total, r.blob_total, r.ws_allocs, r.blob_allocs); +} + +int main() +{ + const int threads = env_int("NCNN_PERF_THREADS", 4); + const int n_ctx = env_int("NCNN_PERF_SDPA_CTX", 4096); + + int prefill_lengths[] = {128, 256, 512, 1024}; + int decode_lengths[] = {128, 512, 1024, 2048}; + + fprintf(stdout, "# Youtu MLA SDPA sweep: d_k=192 d_v=128 heads=128 groups=16 dtype=fp16psa\n"); + fprintf(stdout, "# env filters: NCNN_MLA_SWEEP_MODE=baseline|current NCNN_MLA_SWEEP_PHASE=prefill|decode NCNN_PERF_SDPA_M=... NCNN_PERF_SDPA_PAST=...\n"); + + for (size_t i = 0; i < sizeof(prefill_lengths) / sizeof(prefill_lengths[0]); i++) + { + const int M = prefill_lengths[i]; + if (!env_match_str("NCNN_MLA_SWEEP_PHASE", "prefill") || !env_match_int("NCNN_PERF_SDPA_M", M)) + continue; + + if (env_match_str("NCNN_MLA_SWEEP_MODE", "baseline")) + print_report("baseline", "prefill", 0, M, n_ctx, threads, run_sdpa("prefill", 0, M, 0, n_ctx, threads)); + if (env_match_str("NCNN_MLA_SWEEP_MODE", "current")) + print_report("current", "prefill", 2, M, n_ctx, threads, run_sdpa("prefill", 2, M, 0, n_ctx, threads)); + } + + for (size_t i = 0; i < sizeof(decode_lengths) / sizeof(decode_lengths[0]); i++) + { + const int past = decode_lengths[i]; + if (!env_match_str("NCNN_MLA_SWEEP_PHASE", "decode") || !env_match_int("NCNN_PERF_SDPA_PAST", past)) + continue; + + if (env_match_str("NCNN_MLA_SWEEP_MODE", "baseline")) + print_report("baseline", "decode", 1, past, n_ctx, threads, run_sdpa("decode", 1, 1, past, n_ctx, threads)); + if (env_match_str("NCNN_MLA_SWEEP_MODE", "current")) + print_report("current", "decode", 2, past, n_ctx, threads, run_sdpa("decode", 2, 1, past, n_ctx, threads)); + } + + return 0; +} diff --git a/tests/perf/perfutil.cpp b/tests/perf/perfutil.cpp index f76d8b0ddc7f..ec8e7d5aeefe 100644 --- a/tests/perf/perfutil.cpp +++ b/tests/perf/perfutil.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #if NCNN_VULKAN @@ -25,6 +26,40 @@ #define PERF_RUN_COUNT 20 #define PERF_TARGET_MIN_MS 5.0 +int perf_env_int(const char* name, int default_value, int min_value) +{ + const char* s = getenv(name); + if (!s || !s[0]) + return default_value; + + int v = atoi(s); + return v < min_value ? min_value : v; +} + +bool perf_has_env(const char* name) +{ + const char* s = getenv(name); + return s && s[0]; +} + +bool perf_match_env_int(const char* name, int value) +{ + const char* s = getenv(name); + if (!s || !s[0]) + return true; + + return atoi(s) == value; +} + +bool perf_match_env_string(const char* name, const char* value) +{ + const char* s = getenv(name); + if (!s || !s[0]) + return true; + + return strcmp(s, value) == 0; +} + // benchmark result for a single test case struct PerfResult { @@ -220,6 +255,25 @@ static void convert_input_layout(const ncnn::Mat& src, ncnn::Mat& dst, const ncn } } +static void convert_input_layout_persistent_view(const ncnn::Mat& src, ncnn::Mat& dst, const ncnn::Option& opt, const ncnn::Layer* op) +{ + ncnn::Mat src_full = src; + const int capacity = src.w == 0 ? src.h : (int)(src.cstep / src.w); + src_full.h = capacity; + + convert_input_layout(src_full, dst, opt, op); + dst.h = src.h; +} + +static bool is_sdpa_persistent_kvcache_input(const char* layer_type, const ncnn::ParamDict& pd, size_t input_index) +{ + if (strcmp(layer_type, "SDPA") != 0 || pd.get(7, 0) != 2) + return false; + + const int blob_offset = pd.get(5, 0) ? 4 : 3; + return input_index == (size_t)blob_offset || input_index == (size_t)(blob_offset + 1); +} + // run a single forward pass (pure compute, no conversion) static int run_layer_forward_cpu(ncnn::Layer* op, const std::vector& converted_inputs, @@ -286,7 +340,10 @@ static int perf_layer_cpu(const char* layer_type, const ncnn::ParamDict& pd, std::vector converted(inputs.size()); for (size_t i = 0; i < inputs.size(); i++) { - convert_input_layout(inputs[i], converted[i], opt, op); + if (is_sdpa_persistent_kvcache_input(layer_type, pd, i)) + convert_input_layout_persistent_view(inputs[i], converted[i], opt, op); + else + convert_input_layout(inputs[i], converted[i], opt, op); } // warmup and calibrate inner loop count from warmup min time diff --git a/tests/perf/perfutil.h b/tests/perf/perfutil.h index 06b31eeef508..a6e762426fb3 100644 --- a/tests/perf/perfutil.h +++ b/tests/perf/perfutil.h @@ -14,6 +14,11 @@ ncnn::Mat PerfMat(int w, int h, float v = 0.01f); ncnn::Mat PerfMat(int w, int h, int c, float v = 0.01f); ncnn::Mat PerfMat(int w, int h, int d, int c, float v = 0.01f); +int perf_env_int(const char* name, int default_value, int min_value = 0); +bool perf_has_env(const char* name); +bool perf_match_env_int(const char* name, int value); +bool perf_match_env_string(const char* name, const char* value); + // high-level perf entry point: benchmark a layer across all precision and GPU variations // layer_type: ncnn layer type name (e.g. "Convolution") // pd, weights, input(s): layer configuration diff --git a/tests/test_sdpa_kvcache.cpp b/tests/test_sdpa_kvcache.cpp index 1fc84f9c72b3..dcdc63679fb4 100644 --- a/tests/test_sdpa_kvcache.cpp +++ b/tests/test_sdpa_kvcache.cpp @@ -5,6 +5,20 @@ #include +static ncnn::Mat CausalMask(int src_seqlen, int past_seqlen) +{ + const int total = past_seqlen + src_seqlen; + ncnn::Mat mask(total, src_seqlen); + mask.fill(0.f); + for (int i = 0; i < src_seqlen; i++) + { + float* row = mask.row(i); + for (int j = past_seqlen + i + 1; j < total; j++) + row[j] = -1e38f; + } + return mask; +} + static int test_sdpa_kvcache(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int attn_mask, int past_seqlen) { const int embed_dim = q.w; @@ -41,6 +55,285 @@ static int test_sdpa_kvcache(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn: return ret; } +static int test_sdpa_kvcache2(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int attn_mask, int past_seqlen, int max_seqlen) +{ + const int embed_dim = q.w; + const int out_embed_dim = v.w; + const int src_seqlen = q.h; + const int cur_seqlen = k.h; + const int dst_seqlen = past_seqlen + cur_seqlen; + + ncnn::Mat past_key_full = RandomMat(embed_dim, max_seqlen, k.c); + ncnn::Mat past_value_full = RandomMat(out_embed_dim, max_seqlen, v.c); + ncnn::Mat past_key = past_key_full; + ncnn::Mat past_value = past_value_full; + past_key.h = past_seqlen; + past_value.h = past_seqlen; + + ncnn::Mat mask; + if (attn_mask) + mask = CausalMask(src_seqlen, past_seqlen); + + ncnn::ParamDict pd1; + pd1.set(5, attn_mask); + pd1.set(7, 1); // kv_cache + + ncnn::Layer* op1 = ncnn::create_layer_cpu("SDPA"); + op1->load_param(pd1); + + ncnn::ParamDict pd2; + pd2.set(5, attn_mask); + pd2.set(7, 2); // kv_cache + + ncnn::Layer* op2 = ncnn::create_layer_cpu("SDPA"); + op2->load_param(pd2); + + ncnn::Option opt; + opt.num_threads = 1; + + op1->create_pipeline(opt); + op2->create_pipeline(opt); + + std::vector as1(attn_mask ? 6 : 5); + std::vector as2(attn_mask ? 6 : 5); + as1[0] = q; + as1[1] = k; + as1[2] = v; + as2[0] = q; + as2[1] = k; + as2[2] = v; + int offset = 3; + if (attn_mask) + { + as1[offset] = mask; + as2[offset] = mask; + offset++; + } + as1[offset] = past_key; + as1[offset + 1] = past_value; + as2[offset] = past_key; + as2[offset + 1] = past_value; + + std::vector out1(3); + std::vector out2(3); + int ret1 = op1->forward(as1, out1, opt); + int ret2 = op2->forward(as2, out2, opt); + + int ret = 0; + if (ret1 != 0 || ret2 != 0 || CompareMat(out1[0], out2[0], 0.001f) != 0) + ret = -1; + if (ret == 0 && (out2[1].data != past_key_full.data || out2[2].data != past_value_full.data || out2[1].h != dst_seqlen || out2[2].h != dst_seqlen)) + ret = -1; + + op1->destroy_pipeline(opt); + op2->destroy_pipeline(opt); + delete op1; + delete op2; + + if (ret != 0) + fprintf(stderr, "test_sdpa_kvcache2 failed q=(%d %d %d) k=(%d %d %d) v=(%d %d %d) attn_mask=%d past_seqlen=%d max_seqlen=%d\n", q.w, q.h, q.c, k.w, k.h, k.c, v.w, v.h, v.c, attn_mask, past_seqlen, max_seqlen); + + return ret; +} + +static int test_sdpa_kvcache2_invalid_view() +{ + ncnn::ParamDict pd; + pd.set(7, 2); // kv_cache + + ncnn::Layer* op = ncnn::create_layer_cpu("SDPA"); + op->load_param(pd); + + ncnn::Option opt; + opt.num_threads = 1; + + std::vector as(4); + as[0] = RandomMat(32, 1, 4); + as[1] = RandomMat(32, 1, 4); + as[2] = RandomMat(20, 1, 4); + as[3] = RandomMat(32, 8, 4); + + std::vector top_blobs(3); + int ret = op->forward(as, top_blobs, opt); + if (ret == 0) + { + fprintf(stderr, "test_sdpa_kvcache2_invalid_view failed missing past_value\n"); + delete op; + return -1; + } + + as.push_back(RandomMat(20, 7, 4)); + ret = op->forward(as, top_blobs, opt); + if (ret == 0) + { + fprintf(stderr, "test_sdpa_kvcache2_invalid_view failed mismatched cache views\n"); + delete op; + return -1; + } + + delete op; + return 0; +} + +static int test_sdpa_kvcache2_persistent_buffer() +{ + const int embed_dim = 4; + const int out_embed_dim = 3; + const int past_seqlen = 2; + const int cur_seqlen = 2; + const int max_seqlen = 5; + + ncnn::Mat query(embed_dim, cur_seqlen, 2); + query.fill(0.01f); + + ncnn::Mat cur_key(embed_dim, cur_seqlen, 1); + ncnn::Mat cur_value(out_embed_dim, cur_seqlen, 1); + ncnn::Mat past_key_full(embed_dim, max_seqlen, 1); + ncnn::Mat past_value_full(out_embed_dim, max_seqlen, 1); + + for (int y = 0; y < cur_seqlen; y++) + { + float* kptr = cur_key.row(y); + for (int x = 0; x < embed_dim; x++) + kptr[x] = 100.f + y * 10 + x; + + float* vptr = cur_value.row(y); + for (int x = 0; x < out_embed_dim; x++) + vptr[x] = 200.f + y * 10 + x; + } + + for (int y = 0; y < max_seqlen; y++) + { + float* kptr = past_key_full.row(y); + for (int x = 0; x < embed_dim; x++) + kptr[x] = 300.f + y * 10 + x; + + float* vptr = past_value_full.row(y); + for (int x = 0; x < out_embed_dim; x++) + vptr[x] = 400.f + y * 10 + x; + } + + ncnn::Mat past_key = past_key_full; + ncnn::Mat past_value = past_value_full; + past_key.h = past_seqlen; + past_value.h = past_seqlen; + + ncnn::ParamDict pd; + pd.set(7, 2); // kv_cache + + ncnn::Layer* op = ncnn::create_layer_cpu("SDPA"); + op->load_param(pd); + + ncnn::Option opt; + opt.num_threads = 1; + + op->create_pipeline(opt); + + std::vector bottom_blobs(5); + bottom_blobs[0] = query; + bottom_blobs[1] = cur_key; + bottom_blobs[2] = cur_value; + bottom_blobs[3] = past_key; + bottom_blobs[4] = past_value; + + std::vector top_blobs(3); + int ret = op->forward(bottom_blobs, top_blobs, opt); + if (ret != 0) + { + fprintf(stderr, "test_sdpa_kvcache2_persistent_buffer failed forward\n"); + op->destroy_pipeline(opt); + delete op; + return -1; + } + + if (top_blobs[1].data != past_key_full.data || top_blobs[2].data != past_value_full.data) + { + fprintf(stderr, "test_sdpa_kvcache2_persistent_buffer failed buffer identity\n"); + op->destroy_pipeline(opt); + delete op; + return -1; + } + + if (top_blobs[1].h != past_seqlen + cur_seqlen || top_blobs[2].h != past_seqlen + cur_seqlen || top_blobs[1].cstep != past_key_full.cstep || top_blobs[2].cstep != past_value_full.cstep) + { + fprintf(stderr, "test_sdpa_kvcache2_persistent_buffer failed output view shape\n"); + op->destroy_pipeline(opt); + delete op; + return -1; + } + + for (int y = 0; y < past_seqlen; y++) + { + const float* kptr = past_key_full.row(y); + for (int x = 0; x < embed_dim; x++) + { + if (kptr[x] != 300.f + y * 10 + x) + { + fprintf(stderr, "test_sdpa_kvcache2_persistent_buffer clobbered past_key\n"); + op->destroy_pipeline(opt); + delete op; + return -1; + } + } + + const float* vptr = past_value_full.row(y); + for (int x = 0; x < out_embed_dim; x++) + { + if (vptr[x] != 400.f + y * 10 + x) + { + fprintf(stderr, "test_sdpa_kvcache2_persistent_buffer clobbered past_value\n"); + op->destroy_pipeline(opt); + delete op; + return -1; + } + } + } + + for (int y = 0; y < cur_seqlen; y++) + { + const float* kptr = past_key_full.row(past_seqlen + y); + for (int x = 0; x < embed_dim; x++) + { + if (kptr[x] != 100.f + y * 10 + x) + { + fprintf(stderr, "test_sdpa_kvcache2_persistent_buffer failed appended key\n"); + op->destroy_pipeline(opt); + delete op; + return -1; + } + } + + const float* vptr = past_value_full.row(past_seqlen + y); + for (int x = 0; x < out_embed_dim; x++) + { + if (vptr[x] != 200.f + y * 10 + x) + { + fprintf(stderr, "test_sdpa_kvcache2_persistent_buffer failed appended value\n"); + op->destroy_pipeline(opt); + delete op; + return -1; + } + } + } + + past_key.h = max_seqlen - cur_seqlen + 1; + past_value.h = max_seqlen - cur_seqlen + 1; + bottom_blobs[3] = past_key; + bottom_blobs[4] = past_value; + ret = op->forward(bottom_blobs, top_blobs, opt); + if (ret == 0) + { + fprintf(stderr, "test_sdpa_kvcache2_persistent_buffer failed overflow check\n"); + op->destroy_pipeline(opt); + delete op; + return -1; + } + + op->destroy_pipeline(opt); + delete op; + return 0; +} + static int test_sdpa_0() { return 0 @@ -53,7 +346,17 @@ static int test_sdpa_0() || test_sdpa_kvcache(RandomMat(44, 128, 4), RandomMat(44, 123, 4), RandomMat(55, 123, 4), 0, 0) || test_sdpa_kvcache(RandomMat(12, 127, 4), RandomMat(12, 127, 4), RandomMat(55, 127, 4), 1, 0) || test_sdpa_kvcache(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0, 3) - || test_sdpa_kvcache(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, 5); + || test_sdpa_kvcache(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, 5) + || test_sdpa_kvcache2(RandomMat(32, 1, 8), RandomMat(32, 1, 8), RandomMat(20, 1, 8), 0, 11, 32) + || test_sdpa_kvcache2(RandomMat(64, 1, 12), RandomMat(64, 1, 2), RandomMat(64, 1, 2), 0, 1, 8) + || test_sdpa_kvcache2(RandomMat(64, 1, 12), RandomMat(64, 1, 2), RandomMat(96, 1, 2), 0, 1, 8) + || test_sdpa_kvcache2(RandomMat(28, 1, 15), RandomMat(28, 1, 5), RandomMat(32, 1, 5), 0, 3, 16) + || test_sdpa_kvcache2(RandomMat(32, 16, 8), RandomMat(32, 16, 8), RandomMat(20, 16, 8), 0, 0, 32) + || test_sdpa_kvcache2(RandomMat(64, 17, 12), RandomMat(64, 17, 2), RandomMat(64, 17, 2), 0, 0, 32) + || test_sdpa_kvcache2_invalid_view() + || test_sdpa_kvcache2_persistent_buffer() + || test_sdpa_kvcache2(RandomMat(32, 16, 8), RandomMat(32, 16, 8), RandomMat(20, 16, 8), 1, 0, 32) + || test_sdpa_kvcache2(RandomMat(64, 17, 12), RandomMat(64, 17, 2), RandomMat(64, 17, 2), 1, 0, 32); } #if NCNN_INT8 @@ -96,6 +399,88 @@ static int test_sdpa_int8_kvcache(const ncnn::Mat& q, const ncnn::Mat& k, const return ret; } +static int test_sdpa_int8_kvcache2(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int attn_mask, int past_seqlen, int max_seqlen) +{ + const int embed_dim = q.w; + const int out_embed_dim = v.w; + const int src_seqlen = q.h; + const int cur_seqlen = k.h; + const int dst_seqlen = past_seqlen + cur_seqlen; + + ncnn::Mat past_key_full = RandomMat(embed_dim, max_seqlen, k.c); + ncnn::Mat past_value_full = RandomMat(out_embed_dim, max_seqlen, v.c); + ncnn::Mat past_key = past_key_full; + ncnn::Mat past_value = past_value_full; + past_key.h = past_seqlen; + past_value.h = past_seqlen; + + ncnn::Mat mask; + if (attn_mask) + mask = CausalMask(src_seqlen, past_seqlen); + + ncnn::ParamDict pd1; + pd1.set(5, attn_mask); + pd1.set(7, 1); // kv_cache + pd1.set(18, 2); // int8_scale_term + + ncnn::ParamDict pd2; + pd2.set(5, attn_mask); + pd2.set(7, 2); // kv_cache + pd2.set(18, 2); // int8_scale_term + + ncnn::Layer* op2 = ncnn::create_layer_cpu("SDPA"); + op2->load_param(pd2); + std::vector weights(0); + ncnn::ModelBinFromMatArray mb(weights.data()); + op2->load_model(mb); + + float epsilon = 0.01; + + ncnn::Option opt; + opt.num_threads = 1; + + op2->create_pipeline(opt); + + std::vector as1(attn_mask ? 6 : 5); + std::vector as2(attn_mask ? 6 : 5); + as1[0] = q; + as1[1] = k; + as1[2] = v; + as2[0] = q; + as2[1] = k; + as2[2] = v; + int offset = 3; + if (attn_mask) + { + as1[offset] = mask; + as2[offset] = mask; + offset++; + } + as1[offset] = past_key; + as1[offset + 1] = past_value; + as2[offset] = past_key; + as2[offset + 1] = past_value; + + std::vector out1(3); + std::vector out2(3); + int ret1 = test_layer_naive(ncnn::layer_to_index("SDPA"), pd1, weights, as1, 3, out1, 0); + int ret2 = op2->forward(as2, out2, opt); + + int ret = 0; + if (ret1 != 0 || ret2 != 0 || CompareMat(out1[0], out2[0], epsilon) != 0) + ret = -1; + if (ret == 0 && (out2[1].data != past_key_full.data || out2[2].data != past_value_full.data || out2[1].h != dst_seqlen || out2[2].h != dst_seqlen)) + ret = -1; + + op2->destroy_pipeline(opt); + delete op2; + + if (ret != 0) + fprintf(stderr, "test_sdpa_int8_kvcache2 failed q=(%d %d %d) k=(%d %d %d) v=(%d %d %d) attn_mask=%d past_seqlen=%d max_seqlen=%d\n", q.w, q.h, q.c, k.w, k.h, k.c, v.w, v.h, v.c, attn_mask, past_seqlen, max_seqlen); + + return ret; +} + static int test_sdpa_1() { return 0 @@ -108,7 +493,11 @@ static int test_sdpa_1() || test_sdpa_int8_kvcache(RandomMat(44, 128, 4), RandomMat(44, 123, 4), RandomMat(55, 123, 4), 0, 0) || test_sdpa_int8_kvcache(RandomMat(12, 127, 4), RandomMat(12, 127, 4), RandomMat(55, 127, 4), 1, 0) || test_sdpa_int8_kvcache(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0, 3) - || test_sdpa_int8_kvcache(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, 5); + || test_sdpa_int8_kvcache(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, 5) + || test_sdpa_int8_kvcache2(RandomMat(32, 1, 8), RandomMat(32, 1, 8), RandomMat(20, 1, 8), 0, 11, 32) + || test_sdpa_int8_kvcache2(RandomMat(64, 1, 12), RandomMat(64, 1, 2), RandomMat(64, 1, 2), 0, 1, 8) + || test_sdpa_int8_kvcache2(RandomMat(32, 16, 8), RandomMat(32, 16, 8), RandomMat(20, 16, 8), 1, 0, 32) + || test_sdpa_int8_kvcache2(RandomMat(64, 17, 12), RandomMat(64, 17, 2), RandomMat(64, 17, 2), 1, 0, 32); } #endif