Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 74 additions & 9 deletions src/layer/arm/sdpa_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@

namespace ncnn {

static Mat make_persistent_kvcache_view_arm(const Mat& cache, int seqlen)
{
Mat view = cache;
view.h = seqlen;
return view;
}

SDPA_arm::SDPA_arm()
{
#if NCNN_ARM82
Expand Down Expand Up @@ -151,28 +158,83 @@ int SDPA_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
if (int8_scale_term)
{
opt.use_packing_layout = false; // TODO enable packing
if (kv_cache == 2)
return SDPA::forward(bottom_blobs, top_blobs, opt);
}

const Mat& query = bottom_blobs[0];
const Mat& cur_key = bottom_blobs[1];
const Mat& cur_value = bottom_blobs[2];
const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat();
const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat();
const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat();
const int blob_offset = attn_mask ? 4 : 3;
if (kv_cache == 2 && (int)bottom_blobs.size() <= blob_offset + 1)
return -1;
const Mat& past_key = kv_cache ? bottom_blobs[blob_offset] : Mat();
const Mat& past_value = kv_cache ? bottom_blobs[blob_offset + 1] : Mat();

const int embed_dim = query.w;
const int src_seqlen = query.h;
const int num_heads = query.c;
const int cur_seqlen = cur_key.h;
const int num_group = cur_key.c;
const int out_embed_dim = cur_value.w;
const int past_seqlen = kv_cache ? past_key.h : 0;

int past_seqlen = 0;
if (kv_cache == 2)
{
if (past_key.dims == 0 || past_value.dims == 0)
return -1;
past_seqlen = past_key.h;
}
else if (kv_cache == 1 && past_key.dims > 0)
past_seqlen = past_key.h;

if (kv_cache == 2 && past_value.h != past_seqlen)
return -1;

const int dst_seqlen = past_seqlen + cur_seqlen;

const size_t elemsize = query.elemsize;
const int num_heads_per_group = num_heads / num_group;

// kv_cache==2: in-place append to preallocated buffer
if (kv_cache == 2 && past_key.dims > 0)
{
const int key_capacity = (int)(past_key.cstep / embed_dim);
const int value_capacity = (int)(past_value.cstep / out_embed_dim);
if (dst_seqlen > key_capacity || dst_seqlen > value_capacity)
return -1;

// In-place append cur_key/cur_value to preallocated buffer
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_group; q++)
{
unsigned char* pk = (unsigned char*)past_key.channel(q).data;
unsigned char* pv = (unsigned char*)past_value.channel(q).data;
memcpy(pk + (size_t)past_seqlen * embed_dim * elemsize,
cur_key.channel(q).data, embed_dim * cur_seqlen * elemsize);
memcpy(pv + (size_t)past_seqlen * out_embed_dim * elemsize,
cur_value.channel(q).data, out_embed_dim * cur_seqlen * elemsize);
}
}

Mat key;
if (past_seqlen > 0)
Mat value;
if (kv_cache == 2 && past_key.dims > 0)
{
// Copy dst_seqlen rows from preallocated buffer (already appended above)
key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator);
if (key.empty()) return -100;
value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator);
if (value.empty()) return -100;
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_group; q++)
{
memcpy(key.channel(q), past_key.channel(q), embed_dim * dst_seqlen * elemsize);
memcpy(value.channel(q), past_value.channel(q), out_embed_dim * dst_seqlen * elemsize);
}
}
else if (past_seqlen > 0)
{
key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator);
if (key.empty())
Expand All @@ -194,8 +256,7 @@ int SDPA_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
key = cur_key;
}

Mat value;
if (past_seqlen > 0)
if (kv_cache != 2 && past_seqlen > 0)
{
value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator);
if (value.empty())
Expand All @@ -212,12 +273,11 @@ int SDPA_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * elemsize);
}
}
else
else if (kv_cache != 2)
{
value = cur_value;
}

const int num_heads_per_group = num_heads / num_group;

#if NCNN_BF16
const bool use_bf16_storage = opt.use_bf16_storage && !opt.use_fp16_storage && query.elembits() == 16;
Expand Down Expand Up @@ -361,7 +421,12 @@ int SDPA_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to

value_fp32.release();

if (kv_cache)
if (kv_cache == 2)
{
top_blobs[1] = make_persistent_kvcache_view_arm(past_key, dst_seqlen);
top_blobs[2] = make_persistent_kvcache_view_arm(past_value, dst_seqlen);
}
else if (kv_cache)
{
top_blobs[1] = key;
top_blobs[2] = value;
Expand Down
76 changes: 68 additions & 8 deletions src/layer/loongarch/sdpa_loongarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,18 @@

#include "sdpa_loongarch.h"

#include "cpu.h"
#include "layer_type.h"

namespace ncnn {

static Mat make_persistent_kvcache_view_loongarch(const Mat& cache, int seqlen)
{
Mat view = cache;
view.h = seqlen;
return view;
}

SDPA_loongarch::SDPA_loongarch()
{
#if NCNN_BF16
Expand Down Expand Up @@ -143,22 +151,70 @@ int SDPA_loongarch::forward(const std::vector<Mat>& bottom_blobs, std::vector<Ma
const Mat& cur_key = bottom_blobs[1];
const Mat& cur_value = bottom_blobs[2];
const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat();
const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat();
const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat();
const int blob_offset = attn_mask ? 4 : 3;
if (kv_cache == 2 && (int)bottom_blobs.size() <= blob_offset + 1)
return -1;
const Mat& past_key = kv_cache ? bottom_blobs[blob_offset] : Mat();
const Mat& past_value = kv_cache ? bottom_blobs[blob_offset + 1] : Mat();

const int embed_dim = query.w;
const int src_seqlen = query.h;
const int num_heads = query.c;
const int cur_seqlen = cur_key.h;
const int num_group = cur_key.c;
const int out_embed_dim = cur_value.w;
const int past_seqlen = kv_cache ? past_key.h : 0;
int past_seqlen = 0;
if (kv_cache == 2)
{
if (past_key.dims == 0 || past_value.dims == 0)
return -1;
past_seqlen = past_key.h;
}
else if (kv_cache == 1 && past_key.dims > 0)
past_seqlen = past_key.h;

if (kv_cache == 2 && past_value.h != past_seqlen)
return -1;

const int dst_seqlen = past_seqlen + cur_seqlen;

const size_t elemsize = query.elemsize;

Mat key;
if (past_seqlen > 0)
Mat value;
if (kv_cache == 2 && past_key.dims > 0)
{
const int key_capacity = (int)(past_key.cstep / embed_dim);
const int value_capacity = (int)(past_value.cstep / out_embed_dim);
if (dst_seqlen > key_capacity || dst_seqlen > value_capacity)
return -1;

#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_group; q++)
{
unsigned char* pk = (unsigned char*)past_key.channel(q).data;
unsigned char* pv = (unsigned char*)past_value.channel(q).data;
memcpy(pk + (size_t)past_seqlen * embed_dim * elemsize,
cur_key.channel(q).data, embed_dim * cur_seqlen * elemsize);
memcpy(pv + (size_t)past_seqlen * out_embed_dim * elemsize,
cur_value.channel(q).data, out_embed_dim * cur_seqlen * elemsize);
}

key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator);
if (key.empty())
return -100;
value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator);
if (value.empty())
return -100;

#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_group; q++)
{
memcpy(key.channel(q), past_key.channel(q), embed_dim * dst_seqlen * elemsize);
memcpy(value.channel(q), past_value.channel(q), out_embed_dim * dst_seqlen * elemsize);
}
}
else if (past_seqlen > 0)
{
key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator);
if (key.empty())
Expand All @@ -180,8 +236,7 @@ int SDPA_loongarch::forward(const std::vector<Mat>& bottom_blobs, std::vector<Ma
key = cur_key;
}

Mat value;
if (past_seqlen > 0)
if (kv_cache != 2 && past_seqlen > 0)
{
value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator);
if (value.empty())
Expand All @@ -198,7 +253,7 @@ int SDPA_loongarch::forward(const std::vector<Mat>& bottom_blobs, std::vector<Ma
memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * elemsize);
}
}
else
else if (kv_cache != 2)
{
value = cur_value;
}
Expand Down Expand Up @@ -337,7 +392,12 @@ int SDPA_loongarch::forward(const std::vector<Mat>& bottom_blobs, std::vector<Ma

value_fp32.release();

if (kv_cache)
if (kv_cache == 2)
{
top_blobs[1] = make_persistent_kvcache_view_loongarch(past_key, dst_seqlen);
top_blobs[2] = make_persistent_kvcache_view_loongarch(past_value, dst_seqlen);
}
else if (kv_cache)
{
top_blobs[1] = key;
top_blobs[2] = value;
Expand Down
76 changes: 68 additions & 8 deletions src/layer/mips/sdpa_mips.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,18 @@

#include "sdpa_mips.h"

#include "cpu.h"
#include "layer_type.h"

namespace ncnn {

static Mat make_persistent_kvcache_view_mips(const Mat& cache, int seqlen)
{
Mat view = cache;
view.h = seqlen;
return view;
}

SDPA_mips::SDPA_mips()
{
#if NCNN_BF16
Expand Down Expand Up @@ -143,22 +151,70 @@ int SDPA_mips::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& t
const Mat& cur_key = bottom_blobs[1];
const Mat& cur_value = bottom_blobs[2];
const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat();
const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat();
const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat();
const int blob_offset = attn_mask ? 4 : 3;
if (kv_cache == 2 && (int)bottom_blobs.size() <= blob_offset + 1)
return -1;
const Mat& past_key = kv_cache ? bottom_blobs[blob_offset] : Mat();
const Mat& past_value = kv_cache ? bottom_blobs[blob_offset + 1] : Mat();

const int embed_dim = query.w;
const int src_seqlen = query.h;
const int num_heads = query.c;
const int cur_seqlen = cur_key.h;
const int num_group = cur_key.c;
const int out_embed_dim = cur_value.w;
const int past_seqlen = kv_cache ? past_key.h : 0;
int past_seqlen = 0;
if (kv_cache == 2)
{
if (past_key.dims == 0 || past_value.dims == 0)
return -1;
past_seqlen = past_key.h;
}
else if (kv_cache == 1 && past_key.dims > 0)
past_seqlen = past_key.h;

if (kv_cache == 2 && past_value.h != past_seqlen)
return -1;

const int dst_seqlen = past_seqlen + cur_seqlen;

const size_t elemsize = query.elemsize;

Mat key;
if (past_seqlen > 0)
Mat value;
if (kv_cache == 2 && past_key.dims > 0)
{
const int key_capacity = (int)(past_key.cstep / embed_dim);
const int value_capacity = (int)(past_value.cstep / out_embed_dim);
if (dst_seqlen > key_capacity || dst_seqlen > value_capacity)
return -1;

#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_group; q++)
{
unsigned char* pk = (unsigned char*)past_key.channel(q).data;
unsigned char* pv = (unsigned char*)past_value.channel(q).data;
memcpy(pk + (size_t)past_seqlen * embed_dim * elemsize,
cur_key.channel(q).data, embed_dim * cur_seqlen * elemsize);
memcpy(pv + (size_t)past_seqlen * out_embed_dim * elemsize,
cur_value.channel(q).data, out_embed_dim * cur_seqlen * elemsize);
}

key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator);
if (key.empty())
return -100;
value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator);
if (value.empty())
return -100;

#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_group; q++)
{
memcpy(key.channel(q), past_key.channel(q), embed_dim * dst_seqlen * elemsize);
memcpy(value.channel(q), past_value.channel(q), out_embed_dim * dst_seqlen * elemsize);
}
}
else if (past_seqlen > 0)
{
key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator);
if (key.empty())
Expand All @@ -180,8 +236,7 @@ int SDPA_mips::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& t
key = cur_key;
}

Mat value;
if (past_seqlen > 0)
if (kv_cache != 2 && past_seqlen > 0)
{
value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator);
if (value.empty())
Expand All @@ -198,7 +253,7 @@ int SDPA_mips::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& t
memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * elemsize);
}
}
else
else if (kv_cache != 2)
{
value = cur_value;
}
Expand Down Expand Up @@ -337,7 +392,12 @@ int SDPA_mips::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& t

value_fp32.release();

if (kv_cache)
if (kv_cache == 2)
{
top_blobs[1] = make_persistent_kvcache_view_mips(past_key, dst_seqlen);
top_blobs[2] = make_persistent_kvcache_view_mips(past_value, dst_seqlen);
}
else if (kv_cache)
{
top_blobs[1] = key;
top_blobs[2] = value;
Expand Down
Loading
Loading