Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 89 additions & 19 deletions src/layer/riscv/gemm_bf16s_fp16s.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ static void pack_A_tile_bf16_fp16(const Mat& A, Mat& AT, int i, int max_ii, int
const size_t vl = __riscv_vsetvl_e16m1(packn);
#endif

const int elempack = A.elempack;
const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w;

unsigned short* pp = AT;

int ii = 0;
#if __riscv_vector
const int elempack = A.elempack;

for (; ii + (packn - 1) < max_ii; ii += packn)
{
if (elempack == packn)
Expand Down Expand Up @@ -209,25 +210,30 @@ static void pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int max_jj, int
#if __riscv_vector
const int packn = csrr_vlenb() / 2;
const size_t vl = __riscv_vsetvl_e16m1(packn);
const size_t vl8 = __riscv_vsetvl_e16m1(8);
const size_t vl4 = __riscv_vsetvl_e16m1(4);
#endif

const int elempack = B.elempack;
const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w;

unsigned short* pp = BT;

int jj = 0;
#if __riscv_vector
for (; jj + (packn - 1) < max_jj; jj += packn)
const int elempack = B.elempack;

for (; jj + 7 < max_jj; jj += 8)
{
if (elempack == packn)
{
const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * packn;
const int q = (j + jj) / packn * packn;
const int r = (j + jj) % packn;
const unsigned short* p0 = (const unsigned short*)B + q * B_hstep + k * packn + r;

for (int kk = 0; kk < max_kk; kk++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl), vl);
pp += packn;
__riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl8), vl8);
pp += 8;
p0 += packn;
}
}
Expand All @@ -237,8 +243,35 @@ static void pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int max_jj, int

for (int kk = 0; kk < max_kk; kk++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0, B_hstep * sizeof(unsigned short), vl), vl);
pp += packn;
__riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0, B_hstep * sizeof(unsigned short), vl8), vl8);
pp += 8;
p0++;
}
}
}
for (; jj + 3 < max_jj; jj += 4)
{
if (elempack == packn)
{
const int q = (j + jj) / packn * packn;
const int r = (j + jj) % packn;
const unsigned short* p0 = (const unsigned short*)B + q * B_hstep + k * packn + r;

for (int kk = 0; kk < max_kk; kk++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl4), vl4);
pp += 4;
p0 += packn;
}
}
if (elempack == 1)
{
const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k;

for (int kk = 0; kk < max_kk; kk++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0, B_hstep * sizeof(unsigned short), vl4), vl4);
pp += 4;
p0++;
}
}
Expand Down Expand Up @@ -303,6 +336,8 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma
#if __riscv_vector
const int packn = csrr_vlenb() / 2;
const size_t vl = __riscv_vsetvl_e16m1(packn);
const size_t vl8 = __riscv_vsetvl_e16m1(8);
const size_t vl4 = __riscv_vsetvl_e16m1(4);
#endif

const int elempack = B.elempack;
Expand All @@ -312,20 +347,53 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma

int jj = 0;
#if __riscv_vector
for (; jj + (packn - 1) < max_jj; jj += packn)
for (; jj + 7 < max_jj; jj += 8)
{
if (elempack == packn)
{
const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * packn;
const int q = (j + jj) / packn * packn;
const int r = (j + jj) % packn;
const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + q * packn + r * packn;

int kk = 0;
for (; kk + (packn - 1) < max_kk; kk += packn)
{
// transposeNxN
for (int l = 0; l < packn; l++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0 + l, packn * sizeof(unsigned short), vl), vl);
pp += packn;
__riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0 + l, packn * sizeof(unsigned short), vl8), vl8);
pp += 8;
}
p0 += B_hstep * packn;
}
}
if (elempack == 1)
{
const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj);

for (int kk = 0; kk < max_kk; kk++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl8), vl8);
pp += 8;
p0 += B_hstep;
}
}
}
for (; jj + 3 < max_jj; jj += 4)
{
if (elempack == packn)
{
const int q = (j + jj) / packn * packn;
const int r = (j + jj) % packn;
const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + q * packn + r * packn;

int kk = 0;
for (; kk + (packn - 1) < max_kk; kk += packn)
{
// transposeNx4
for (int l = 0; l < packn; l++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0 + l, packn * sizeof(unsigned short), vl4), vl4);
pp += 4;
}
p0 += B_hstep * packn;
}
Expand All @@ -337,8 +405,8 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma
int kk = 0;
for (; kk < max_kk; kk++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl), vl);
pp += packn;
__riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl4), vl4);
pp += 4;
p0 += B_hstep;
}
}
Expand Down Expand Up @@ -524,12 +592,14 @@ static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_T

#if __riscv_vector
const int packn = csrr_vlenb() / 2;
const int packn_n = 8;
#else
const int packn = 4;
const int packn_n = 4;
#endif

TILE_M = std::max(packn, tile_size / packn * packn);
TILE_N = std::max(packn, tile_size / packn * packn);
TILE_N = std::max(packn_n, tile_size / packn_n * packn_n);
TILE_K = std::max(packn, tile_size / packn * packn);

if (K > 0)
Expand All @@ -542,7 +612,7 @@ static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_T
tile_size = (int)((float)l2_cache_size / 2 / sizeof(unsigned short) / TILE_K);

TILE_M = std::max(packn, tile_size / packn * packn);
TILE_N = std::max(packn, tile_size / packn * packn);
TILE_N = std::max(packn_n, tile_size / packn_n * packn_n);
}
}

Expand All @@ -557,7 +627,7 @@ static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_T
if (N > 0)
{
int nn_N = (N + TILE_N - 1) / TILE_N;
TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + (packn - 1)) / packn * packn);
TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + (packn_n - 1)) / packn_n * packn_n);
}

if (nT > 1)
Expand All @@ -573,7 +643,7 @@ static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_T

if (constant_TILE_N > 0)
{
TILE_N = (constant_TILE_N + (packn - 1)) / packn * packn;
TILE_N = (constant_TILE_N + (packn_n - 1)) / packn_n * packn_n;
}

if (constant_TILE_K > 0)
Expand Down
Loading
Loading