Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 84 additions & 123 deletions src/layer/riscv/gemm_bf16s_fp16s.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ static void pack_A_tile_bf16_fp16(const Mat& A, Mat& AT, int i, int max_ii, int
const size_t vl = __riscv_vsetvl_e16m1(packn);
#endif

const int elempack = A.elempack;
const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w;

unsigned short* pp = AT;

int ii = 0;
#if __riscv_vector
const int elempack = A.elempack;

for (; ii + (packn - 1) < max_ii; ii += packn)
{
if (elempack == packn)
Expand Down Expand Up @@ -209,25 +210,30 @@ static void pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int max_jj, int
#if __riscv_vector
const int packn = csrr_vlenb() / 2;
const size_t vl = __riscv_vsetvl_e16m1(packn);
const size_t vl8 = __riscv_vsetvl_e16m1(8);
const size_t vl4 = __riscv_vsetvl_e16m1(4);
#endif

const int elempack = B.elempack;
const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w;

unsigned short* pp = BT;

int jj = 0;
#if __riscv_vector
for (; jj + (packn - 1) < max_jj; jj += packn)
const int elempack = B.elempack;

for (; jj + 7 < max_jj; jj += 8)
{
if (elempack == packn)
{
const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * packn;
const int q = (j + jj) / packn * packn;
const int r = (j + jj) % packn;
const unsigned short* p0 = (const unsigned short*)B + q * B_hstep + k * packn + r;

for (int kk = 0; kk < max_kk; kk++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl), vl);
pp += packn;
__riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl8), vl8);
pp += 8;
p0 += packn;
}
}
Expand All @@ -237,8 +243,35 @@ static void pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int max_jj, int

for (int kk = 0; kk < max_kk; kk++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0, B_hstep * sizeof(unsigned short), vl), vl);
pp += packn;
__riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0, B_hstep * sizeof(unsigned short), vl8), vl8);
pp += 8;
p0++;
}
}
}
for (; jj + 3 < max_jj; jj += 4)
{
if (elempack == packn)
{
const int q = (j + jj) / packn * packn;
const int r = (j + jj) % packn;
const unsigned short* p0 = (const unsigned short*)B + q * B_hstep + k * packn + r;

for (int kk = 0; kk < max_kk; kk++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl4), vl4);
pp += 4;
p0 += packn;
}
}
if (elempack == 1)
{
const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k;

for (int kk = 0; kk < max_kk; kk++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0, B_hstep * sizeof(unsigned short), vl4), vl4);
pp += 4;
p0++;
}
}
Expand Down Expand Up @@ -303,6 +336,8 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma
#if __riscv_vector
const int packn = csrr_vlenb() / 2;
const size_t vl = __riscv_vsetvl_e16m1(packn);
const size_t vl8 = __riscv_vsetvl_e16m1(8);
const size_t vl4 = __riscv_vsetvl_e16m1(4);
#endif

const int elempack = B.elempack;
Expand All @@ -312,7 +347,7 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma

int jj = 0;
#if __riscv_vector
for (; jj + (packn - 1) < max_jj; jj += packn)
for (; jj + 7 < max_jj; jj += 8)
{
if (elempack == packn)
{
Expand All @@ -321,11 +356,40 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma
int kk = 0;
for (; kk + (packn - 1) < max_kk; kk += packn)
{
// transposeNxN
for (int l = 0; l < packn; l++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0 + l, packn * sizeof(unsigned short), vl), vl);
pp += packn;
__riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0 + l, packn * sizeof(unsigned short), vl8), vl8);
pp += 8;
}
p0 += B_hstep * packn;
}
}
if (elempack == 1)
{
const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj);

for (int kk = 0; kk < max_kk; kk++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl8), vl8);
pp += 8;
p0 += B_hstep;
}
}
}
for (; jj + 3 < max_jj; jj += 4)
{
if (elempack == packn)
{
const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * packn;

int kk = 0;
for (; kk + (packn - 1) < max_kk; kk += packn)
{
// transposeNx4
for (int l = 0; l < packn; l++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0 + l, packn * sizeof(unsigned short), vl4), vl4);
pp += 4;
}
p0 += B_hstep * packn;
}
Expand All @@ -337,8 +401,8 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma
int kk = 0;
for (; kk < max_kk; kk++)
{
__riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl), vl);
pp += packn;
__riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl4), vl4);
pp += 4;
p0 += B_hstep;
}
}
Expand Down Expand Up @@ -407,111 +471,6 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma
}
}

// Copies a packed 16-bit result tile from topT into top_blob, swapping the
// roles of the two tile indices on the way out.
//
// The values are moved as raw unsigned short bit patterns; no arithmetic or
// format conversion is performed here.
//
// topT      packed source tile, read strictly front to back
// top_blob  destination; only elempack == 1 and (vector builds) elempack == packn
//           are handled, any other elempack is silently ignored
// i, max_ii start and extent along the index that selects the destination column
// j, max_jj start and extent along the index that selects the destination row
//
// The ii range is consumed in groups of packn (vector builds only), then 2,
// then 1. For every group the source stores the group's values for one jj
// next to each other, followed by the values for the next jj.
// NOTE(review): in the elempack == packn branches only whole groups of packn
// along jj are written; callers presumably guarantee max_jj is a multiple of
// packn in that case - confirm against the call sites.
static void transpose_unpack_output_tile_bf16_fp16(const Mat& topT, Mat& top_blob, int i, int max_ii, int j, int max_jj)
{
#if __riscv_vector
    // packn is derived from the vector register byte length; presumably the
    // number of 16-bit lanes per register - TODO confirm csrr_vlenb semantics.
    const int packn = csrr_vlenb() / 2;
    const size_t vl = __riscv_vsetvl_e16m1(packn);
    // byte distance between two consecutive packs in the destination
    const size_t pack_stride_bytes = packn * sizeof(unsigned short);
#endif

    const int dst_pack = top_blob.elempack;

    // distance in elements between two destination rows
    int dst_stride = top_blob.w;
    if (top_blob.dims == 3)
    {
        dst_stride = (int)top_blob.cstep;
    }

    // element offset of the first destination row touched by this tile
    const int row_offset = j * dst_stride;

    const unsigned short* src = topT;

    int ii = 0;

#if __riscv_vector
    // groups of packn along ii
    for (; ii + packn <= max_ii; ii += packn)
    {
        if (dst_pack == packn)
        {
            unsigned short* dst = (unsigned short*)top_blob + row_offset + (i + ii) * packn;

            for (int jj = 0; jj + packn <= max_jj; jj += packn)
            {
                // packn x packn transpose: each contiguous source vector is
                // scattered into lane `lane` of packn consecutive packs
                for (int lane = 0; lane < packn; lane++)
                {
                    const vuint16m1_t v = __riscv_vle16_v_u16m1(src, vl);
                    __riscv_vsse16_v_u16m1(dst + lane, pack_stride_bytes, v, vl);
                    src += packn;
                }
                dst += dst_stride * packn;
            }
        }
        if (dst_pack == 1)
        {
            unsigned short* dst = (unsigned short*)top_blob + row_offset + (i + ii);

            // one contiguous vector per destination row
            for (int jj = 0; jj < max_jj; jj++)
            {
                __riscv_vse16_v_u16m1(dst, __riscv_vle16_v_u16m1(src, vl), vl);
                src += packn;
                dst += dst_stride;
            }
        }
    }
#endif // __riscv_vector

    // groups of 2 along ii
    for (; ii + 2 <= max_ii; ii += 2)
    {
#if __riscv_vector
        if (dst_pack == packn)
        {
            unsigned short* dst = (unsigned short*)top_blob + row_offset + (i + ii) * packn;

            for (int jj = 0; jj + packn <= max_jj; jj += packn)
            {
                // de-interleave the two ii values: segment 0 goes to the first
                // pack, segment 1 to the pack right after it
                const vuint16m1x2_t both = __riscv_vlseg2e16_v_u16m1x2(src, vl);
                const vuint16m1_t first = __riscv_vget_v_u16m1x2_u16m1(both, 0);
                const vuint16m1_t second = __riscv_vget_v_u16m1x2_u16m1(both, 1);
                __riscv_vse16_v_u16m1(dst, first, vl);
                __riscv_vse16_v_u16m1(dst + packn, second, vl);
                src += packn * 2;
                dst += dst_stride * packn;
            }
        }
#endif // __riscv_vector
        if (dst_pack == 1)
        {
            unsigned short* dst = (unsigned short*)top_blob + row_offset + (i + ii);

            for (int jj = 0; jj < max_jj; jj++)
            {
                dst[0] = src[0];
                dst[1] = src[1];
                src += 2;
                dst += dst_stride;
            }
        }
    }

    // leftover single ii
    for (; ii < max_ii; ii++)
    {
#if __riscv_vector
        if (dst_pack == packn)
        {
            unsigned short* dst = (unsigned short*)top_blob + row_offset + (i + ii) * packn;

            for (int jj = 0; jj + packn <= max_jj; jj += packn)
            {
                // packn consecutive jj values fill exactly one destination pack
                __riscv_vse16_v_u16m1(dst, __riscv_vle16_v_u16m1(src, vl), vl);
                src += packn;
                dst += dst_stride * packn;
            }
        }
#endif // __riscv_vector
        if (dst_pack == 1)
        {
            unsigned short* dst = (unsigned short*)top_blob + row_offset + (i + ii);

            for (int jj = 0; jj < max_jj; jj++)
            {
                *dst = *src;
                src++;
                dst += dst_stride;
            }
        }
    }
}

static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int& TILE_M, int& TILE_N, int& TILE_K, int nT)
{
// resolve optimal tile size from cache size
Expand All @@ -524,12 +483,14 @@ static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_T

#if __riscv_vector
const int packn = csrr_vlenb() / 2;
const int packn_n = 8;
#else
const int packn = 4;
const int packn_n = 4;
#endif

TILE_M = std::max(packn, tile_size / packn * packn);
TILE_N = std::max(packn, tile_size / packn * packn);
TILE_N = std::max(packn_n, tile_size / packn_n * packn_n);
TILE_K = std::max(packn, tile_size / packn * packn);

if (K > 0)
Expand All @@ -542,7 +503,7 @@ static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_T
tile_size = (int)((float)l2_cache_size / 2 / sizeof(unsigned short) / TILE_K);

TILE_M = std::max(packn, tile_size / packn * packn);
TILE_N = std::max(packn, tile_size / packn * packn);
TILE_N = std::max(packn_n, tile_size / packn_n * packn_n);
}
}

Expand All @@ -557,7 +518,7 @@ static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_T
if (N > 0)
{
int nn_N = (N + TILE_N - 1) / TILE_N;
TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + (packn - 1)) / packn * packn);
TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + (packn_n - 1)) / packn_n * packn_n);
}

if (nT > 1)
Expand All @@ -573,7 +534,7 @@ static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_T

if (constant_TILE_N > 0)
{
TILE_N = (constant_TILE_N + (packn - 1)) / packn * packn;
TILE_N = (constant_TILE_N + (packn_n - 1)) / packn_n * packn_n;
}

if (constant_TILE_K > 0)
Expand Down
Loading
Loading