diff --git a/src/layer/riscv/gemm_bf16s_fp16s.h b/src/layer/riscv/gemm_bf16s_fp16s.h index 9a08807703e1..ffb77d484d6a 100644 --- a/src/layer/riscv/gemm_bf16s_fp16s.h +++ b/src/layer/riscv/gemm_bf16s_fp16s.h @@ -8,13 +8,14 @@ static void pack_A_tile_bf16_fp16(const Mat& A, Mat& AT, int i, int max_ii, int const size_t vl = __riscv_vsetvl_e16m1(packn); #endif - const int elempack = A.elempack; const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; unsigned short* pp = AT; int ii = 0; #if __riscv_vector + const int elempack = A.elempack; + for (; ii + (packn - 1) < max_ii; ii += packn) { if (elempack == packn) @@ -209,25 +210,30 @@ static void pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int max_jj, int #if __riscv_vector const int packn = csrr_vlenb() / 2; const size_t vl = __riscv_vsetvl_e16m1(packn); + const size_t vl8 = __riscv_vsetvl_e16m1(8); + const size_t vl4 = __riscv_vsetvl_e16m1(4); #endif - const int elempack = B.elempack; const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; unsigned short* pp = BT; int jj = 0; #if __riscv_vector - for (; jj + (packn - 1) < max_jj; jj += packn) + const int elempack = B.elempack; + + for (; jj + 7 < max_jj; jj += 8) { if (elempack == packn) { - const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * packn; + const int q = (j + jj) / packn * packn; + const int r = (j + jj) % packn; + const unsigned short* p0 = (const unsigned short*)B + q * B_hstep + k * packn + r; for (int kk = 0; kk < max_kk; kk++) { - __riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl), vl); - pp += packn; + __riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl8), vl8); + pp += 8; p0 += packn; } } @@ -237,8 +243,35 @@ static void pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int max_jj, int for (int kk = 0; kk < max_kk; kk++) { - __riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0, B_hstep * sizeof(unsigned short), vl), vl); - pp += packn; + __riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0, B_hstep * 
sizeof(unsigned short), vl8), vl8); + pp += 8; + p0++; + } + } + } + for (; jj + 3 < max_jj; jj += 4) + { + if (elempack == packn) + { + const int q = (j + jj) / packn * packn; + const int r = (j + jj) % packn; + const unsigned short* p0 = (const unsigned short*)B + q * B_hstep + k * packn + r; + + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl4), vl4); + pp += 4; + p0 += packn; + } + } + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0, B_hstep * sizeof(unsigned short), vl4), vl4); + pp += 4; p0++; } } @@ -303,6 +336,8 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma #if __riscv_vector const int packn = csrr_vlenb() / 2; const size_t vl = __riscv_vsetvl_e16m1(packn); + const size_t vl8 = __riscv_vsetvl_e16m1(8); + const size_t vl4 = __riscv_vsetvl_e16m1(4); #endif const int elempack = B.elempack; @@ -312,7 +347,7 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma int jj = 0; #if __riscv_vector - for (; jj + (packn - 1) < max_jj; jj += packn) + for (; jj + 7 < max_jj; jj += 8) { if (elempack == packn) { @@ -321,11 +356,40 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma int kk = 0; for (; kk + (packn - 1) < max_kk; kk += packn) { - // transposeNxN for (int l = 0; l < packn; l++) { - __riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0 + l, packn * sizeof(unsigned short), vl), vl); - pp += packn; + __riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0 + l, packn * sizeof(unsigned short), vl8), vl8); + pp += 8; + } + p0 += B_hstep * packn; + } + } + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); + + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl8), vl8); + pp += 8; + 
p0 += B_hstep; + } + } + } + for (; jj + 3 < max_jj; jj += 4) + { + if (elempack == packn) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * packn; + + int kk = 0; + for (; kk + (packn - 1) < max_kk; kk += packn) + { + // transposeNx4 + for (int l = 0; l < packn; l++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0 + l, packn * sizeof(unsigned short), vl4), vl4); + pp += 4; } p0 += B_hstep * packn; } @@ -337,8 +401,8 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma int kk = 0; for (; kk < max_kk; kk++) { - __riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl), vl); - pp += packn; + __riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl4), vl4); + pp += 4; p0 += B_hstep; } } @@ -407,111 +471,6 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma } } -static void transpose_unpack_output_tile_bf16_fp16(const Mat& topT, Mat& top_blob, int i, int max_ii, int j, int max_jj) -{ -#if __riscv_vector - const int packn = csrr_vlenb() / 2; - const size_t vl = __riscv_vsetvl_e16m1(packn); -#endif - - const int out_elempack = top_blob.elempack; - const int out_hstep = top_blob.dims == 3 ? 
(int)top_blob.cstep : top_blob.w; - - const unsigned short* pp = topT; - - int ii = 0; -#if __riscv_vector - for (; ii + (packn - 1) < max_ii; ii += packn) - { - if (out_elempack == packn) - { - unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * packn; - - for (int jj = 0; jj + (packn - 1) < max_jj; jj += packn) - { - // transposeNxN - for (int l = 0; l < packn; l++) - { - __riscv_vsse16_v_u16m1(p0 + l, packn * sizeof(unsigned short), __riscv_vle16_v_u16m1(pp, vl), vl); - pp += packn; - } - p0 += out_hstep * packn; - } - } - if (out_elempack == 1) - { - unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii); - - for (int jj = 0; jj < max_jj; jj += 1) - { - vuint16m1_t _r0 = __riscv_vle16_v_u16m1(pp, vl); - __riscv_vse16_v_u16m1(p0, _r0, vl); - pp += packn; - p0 += out_hstep; - } - } - } -#endif // __riscv_vector - for (; ii + 1 < max_ii; ii += 2) - { -#if __riscv_vector - if (out_elempack == packn) - { - unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * packn; - - for (int jj = 0; jj + (packn - 1) < max_jj; jj += packn) - { - vuint16m1x2_t _s0 = __riscv_vlseg2e16_v_u16m1x2(pp, vl); - __riscv_vse16_v_u16m1(p0, __riscv_vget_v_u16m1x2_u16m1(_s0, 0), vl); - __riscv_vse16_v_u16m1(p0 + packn, __riscv_vget_v_u16m1x2_u16m1(_s0, 1), vl); - pp += packn * 2; - p0 += out_hstep * packn; - } - } -#endif // __riscv_vector - if (out_elempack == 1) - { - unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii); - - for (int jj = 0; jj < max_jj; jj += 1) - { - p0[0] = pp[0]; - p0[1] = pp[1]; - pp += 2; - p0 += out_hstep; - } - } - } - for (; ii < max_ii; ii += 1) - { -#if __riscv_vector - if (out_elempack == packn) - { - unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * packn; - - for (int jj = 0; jj + (packn - 1) < max_jj; jj += packn) - { - vuint16m1_t _r0 = __riscv_vle16_v_u16m1(pp, vl); - __riscv_vse16_v_u16m1(p0, _r0, vl); - pp += packn; - p0 += out_hstep * packn; - 
} - } -#endif // __riscv_vector - if (out_elempack == 1) - { - unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii); - - for (int jj = 0; jj < max_jj; jj += 1) - { - p0[0] = pp[0]; - pp += 1; - p0 += out_hstep; - } - } - } -} - static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) { // resolve optimal tile size from cache size @@ -524,12 +483,14 @@ static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_T #if __riscv_vector const int packn = csrr_vlenb() / 2; + const int packn_n = 8; #else const int packn = 4; + const int packn_n = 4; #endif TILE_M = std::max(packn, tile_size / packn * packn); - TILE_N = std::max(packn, tile_size / packn * packn); + TILE_N = std::max(packn_n, tile_size / packn_n * packn_n); TILE_K = std::max(packn, tile_size / packn * packn); if (K > 0) @@ -542,7 +503,7 @@ static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_T tile_size = (int)((float)l2_cache_size / 2 / sizeof(unsigned short) / TILE_K); TILE_M = std::max(packn, tile_size / packn * packn); - TILE_N = std::max(packn, tile_size / packn * packn); + TILE_N = std::max(packn_n, tile_size / packn_n * packn_n); } } @@ -557,7 +518,7 @@ static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_T if (N > 0) { int nn_N = (N + TILE_N - 1) / TILE_N; - TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + (packn - 1)) / packn * packn); + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + (packn_n - 1)) / packn_n * packn_n); } if (nT > 1) @@ -573,7 +534,7 @@ static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_T if (constant_TILE_N > 0) { - TILE_N = (constant_TILE_N + (packn - 1)) / packn * packn; + TILE_N = (constant_TILE_N + (packn_n - 1)) / packn_n * packn_n; } if (constant_TILE_K > 0) diff --git a/src/layer/riscv/gemm_fp16s.h b/src/layer/riscv/gemm_fp16s.h index 
4efe28da1e97..2cdc3dee0958 100644 --- a/src/layer/riscv/gemm_fp16s.h +++ b/src/layer/riscv/gemm_fp16s.h @@ -1,20 +1,21 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -static void pack_A_tile_fp16s(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +static void pack_A_tile_fp32(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) { #if __riscv_vector const int packn = csrr_vlenb() / 2; const size_t vl = __riscv_vsetvl_e32m2(packn); #endif - const int elempack = A.elempack; const size_t A_hstep = A.dims == 3 ? A.cstep : (size_t)A.w; float* pp = AT; int ii = 0; #if __riscv_vector + const int elempack = A.elempack; + for (; ii + (packn - 1) < max_ii; ii += packn) { if (elempack == packn) @@ -95,7 +96,12 @@ static void pack_A_tile_fp16s(const Mat& A, Mat& AT, int i, int max_ii, int k, i } } -static void pack_A_tile_fp32_to_fp16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +#if __riscv_vector && __riscv_xtheadvector +// FIXME inline causes incorrect codegen on c906 +__attribute__((noinline)) +#endif +static void +pack_A_tile_fp32_to_fp16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) { #if __riscv_vector const int packn = csrr_vlenb() / 2; @@ -148,7 +154,12 @@ static void pack_A_tile_fp32_to_fp16(const Mat& A, Mat& AT, int i, int max_ii, i } } -static void transpose_pack_A_tile_fp32_to_fp16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +#if __riscv_vector && __riscv_xtheadvector +// FIXME inline causes incorrect codegen on c906 +__attribute__((noinline)) +#endif +static void +transpose_pack_A_tile_fp32_to_fp16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) { #if __riscv_vector const int packn = csrr_vlenb() / 2; @@ -201,11 +212,16 @@ static void transpose_pack_A_tile_fp32_to_fp16(const Mat& A, Mat& AT, int i, int } } -static void pack_B_tile_fp32_to_fp16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +#if __riscv_vector && __riscv_xtheadvector +// FIXME 
inline causes incorrect codegen on c906 +__attribute__((noinline)) +#endif +static void +pack_B_tile_fp32_to_fp16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) { #if __riscv_vector - const int packn = csrr_vlenb() / 2; - const size_t vl = __riscv_vsetvl_e16m1(packn); + const size_t vl8 = __riscv_vsetvl_e16m1(8); + const size_t vl4 = __riscv_vsetvl_e16m1(4); #endif const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; @@ -214,15 +230,27 @@ static void pack_B_tile_fp32_to_fp16(const Mat& B, Mat& BT, int j, int max_jj, i int jj = 0; #if __riscv_vector - for (; jj + (packn - 1) < max_jj; jj += packn) + for (; jj + 7 < max_jj; jj += 8) { const float* p0 = (const float*)B + (j + jj) * B_hstep + k; int kk = 0; for (; kk < max_kk; kk++) { - __riscv_vse16_v_f16m1(pp, __riscv_vfncvt_f_f_w_f16m1(__riscv_vlse32_v_f32m2(p0, B_hstep * sizeof(float), vl), vl), vl); - pp += packn; + __riscv_vse16_v_f16m1(pp, __riscv_vfncvt_f_f_w_f16m1(__riscv_vlse32_v_f32m2(p0, B_hstep * sizeof(float), vl8), vl8), vl8); + pp += 8; + p0++; + } + } + for (; jj + 3 < max_jj; jj += 4) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + + int kk = 0; + for (; kk < max_kk; kk++) + { + __riscv_vse16_v_f16m1(pp, __riscv_vfncvt_f_f_w_f16m1(__riscv_vlse32_v_f32m2(p0, B_hstep * sizeof(float), vl4), vl4), vl4); + pp += 4; p0++; } } @@ -256,11 +284,16 @@ static void pack_B_tile_fp32_to_fp16(const Mat& B, Mat& BT, int j, int max_jj, i } } -static void transpose_pack_B_tile_fp32_to_fp16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +#if __riscv_vector && __riscv_xtheadvector +// FIXME inline causes incorrect codegen on c906 +__attribute__((noinline)) +#endif +static void +transpose_pack_B_tile_fp32_to_fp16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) { #if __riscv_vector - const int packn = csrr_vlenb() / 2; - const size_t vl = __riscv_vsetvl_e32m2(packn); + const size_t vl8 = __riscv_vsetvl_e32m2(8); + const size_t vl4 = __riscv_vsetvl_e32m2(4); 
#endif const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; @@ -269,15 +302,27 @@ static void transpose_pack_B_tile_fp32_to_fp16(const Mat& B, Mat& BT, int j, int int jj = 0; #if __riscv_vector - for (; jj + (packn - 1) < max_jj; jj += packn) + for (; jj + 7 < max_jj; jj += 8) { const float* p0 = (const float*)B + k * B_hstep + (j + jj); int kk = 0; for (; kk < max_kk; kk++) { - __riscv_vse16_v_f16m1(pp, __riscv_vfncvt_f_f_w_f16m1(__riscv_vle32_v_f32m2(p0, vl), vl), vl); - pp += packn; + __riscv_vse16_v_f16m1(pp, __riscv_vfncvt_f_f_w_f16m1(__riscv_vle32_v_f32m2(p0, vl8), vl8), vl8); + pp += 8; + p0 += B_hstep; + } + } + for (; jj + 3 < max_jj; jj += 4) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + __riscv_vse16_v_f16m1(pp, __riscv_vfncvt_f_f_w_f16m1(__riscv_vle32_v_f32m2(p0, vl4), vl4), vl4); + pp += 4; p0 += B_hstep; } } @@ -314,9 +359,15 @@ static void transpose_unpack_output_tile_fp32_to_fp16(const Mat& topT, Mat& top_ #if __riscv_vector const int packn = csrr_vlenb() / 2; const size_t vl = __riscv_vsetvl_e32m2(packn); + const size_t vl8 = __riscv_vsetvl_e16m2(8); + const size_t vl8w = __riscv_vsetvl_e32m4(8); + const size_t vl4 = __riscv_vsetvl_e16m1(4); + const size_t vl4w = __riscv_vsetvl_e32m2(4); #endif +#if __riscv_vector const int out_elempack = top_blob.elempack; +#endif const int out_hstep = top_blob.dims == 3 ? 
(int)top_blob.cstep : top_blob.w; const float* pp = topT; @@ -327,9 +378,25 @@ static void transpose_unpack_output_tile_fp32_to_fp16(const Mat& topT, Mat& top_ { if (out_elempack == packn) { - __fp16* p0 = (__fp16*)top_blob + j * out_hstep + (i + ii) * packn; + int jj = 0; + + const int r0 = j % packn; + if (r0 != 0) + { + const int nn = std::min(packn - r0, max_jj); + __fp16* p0 = (__fp16*)top_blob + (j / packn * packn) * out_hstep + r0 + (i + ii) * packn; - for (int jj = 0; jj + (packn - 1) < max_jj; jj += packn) + for (; jj < nn; jj++) + { + __riscv_vsse16_v_f16m1(p0, packn * sizeof(__fp16), __riscv_vfncvt_f_f_w_f16m1(__riscv_vle32_v_f32m2(pp, vl), vl), vl); + pp += packn; + p0++; + } + } + + __fp16* p0 = (__fp16*)top_blob + (j + jj) * out_hstep + (i + ii) * packn; + + for (; jj + (packn - 1) < max_jj; jj += packn) { // transposeNxN for (int l = 0; l < packn; l++) @@ -337,8 +404,16 @@ static void transpose_unpack_output_tile_fp32_to_fp16(const Mat& topT, Mat& top_ __riscv_vsse16_v_f16m1(p0 + l, packn * sizeof(__fp16), __riscv_vfncvt_f_f_w_f16m1(__riscv_vle32_v_f32m2(pp, vl), vl), vl); pp += packn; } + p0 += out_hstep * packn; } + + for (; jj < max_jj; jj++) + { + __riscv_vsse16_v_f16m1(p0, packn * sizeof(__fp16), __riscv_vfncvt_f_f_w_f16m1(__riscv_vle32_v_f32m2(pp, vl), vl), vl); + pp += packn; + p0++; + } } if (out_elempack == 1) { @@ -358,25 +433,75 @@ static void transpose_unpack_output_tile_fp32_to_fp16(const Mat& topT, Mat& top_ #if __riscv_vector if (out_elempack == packn) { - __fp16* p0 = (__fp16*)top_blob + j * out_hstep + (i + ii) * packn; + int jj = 0; - for (int jj = 0; jj + (packn - 1) < max_jj; jj += packn) + for (; jj + 7 < max_jj; jj += 8) { - for (int l = 0; l < packn; l++) - { - p0[l] = (__fp16)pp[l * 2 + 0]; - p0[packn + l] = (__fp16)pp[l * 2 + 1]; - } - pp += packn * 2; - p0 += out_hstep * packn; + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + 
__riscv_vse16_v_f16m2(p0, __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pp, vl8w), vl8), vl8); + __riscv_vse16_v_f16m2(p0 + packn, __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pp + 8, vl8w), vl8), vl8); + pp += 8 * 2; + } + for (; jj + 3 < max_jj; jj += 4) + { + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + __riscv_vse16_v_f16m1(p0, __riscv_vfncvt_f_f_w_f16m1(__riscv_vle32_v_f32m2(pp, vl4w), vl4), vl4); + __riscv_vse16_v_f16m1(p0 + packn, __riscv_vfncvt_f_f_w_f16m1(__riscv_vle32_v_f32m2(pp + 4, vl4w), vl4), vl4); + pp += 4 * 2; + } + for (; jj + 1 < max_jj; jj += 2) + { + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + p0[0] = (__fp16)(pp[0]); + p0[1] = (__fp16)(pp[1]); + p0[packn] = (__fp16)(pp[2]); + p0[packn + 1] = (__fp16)(pp[3]); + pp += 2 * 2; + } + for (; jj < max_jj; jj += 1) + { + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + p0[0] = (__fp16)(pp[0]); + p0[packn] = (__fp16)(pp[1]); + pp += 2; } } -#endif // __riscv_vector if (out_elempack == 1) +#endif // __riscv_vector { __fp16* p0 = (__fp16*)top_blob + j * out_hstep + (i + ii); - for (int jj = 0; jj < max_jj; jj += 1) + int jj = 0; +#if __riscv_vector + for (; jj + 7 < max_jj; jj += 8) + { + __riscv_vsse16_v_f16m2(p0, out_hstep * sizeof(__fp16), __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pp, vl8w), vl8), vl8); + __riscv_vsse16_v_f16m2(p0 + 1, out_hstep * sizeof(__fp16), __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pp + 8, vl8w), vl8), vl8); + pp += 8 * 2; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + __riscv_vsse16_v_f16m1(p0, out_hstep * sizeof(__fp16), __riscv_vfncvt_f_f_w_f16m1(__riscv_vle32_v_f32m2(pp, vl4w), vl4), vl4); + __riscv_vsse16_v_f16m1(p0 + 1, out_hstep * sizeof(__fp16), 
__riscv_vfncvt_f_f_w_f16m1(__riscv_vle32_v_f32m2(pp + 4, vl4w), vl4), vl4); + pp += 4 * 2; + p0 += out_hstep * 4; + } +#endif // __riscv_vector + for (; jj + 1 < max_jj; jj += 2) + { + p0[0] = (__fp16)(pp[0]); + p0[out_hstep] = (__fp16)(pp[1]); + p0[1] = (__fp16)(pp[2]); + p0[out_hstep + 1] = (__fp16)(pp[3]); + pp += 2 * 2; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) { p0[0] = (__fp16)(pp[0]); p0[1] = (__fp16)(pp[1]); @@ -390,24 +515,51 @@ static void transpose_unpack_output_tile_fp32_to_fp16(const Mat& topT, Mat& top_ #if __riscv_vector if (out_elempack == packn) { - __fp16* p0 = (__fp16*)top_blob + j * out_hstep + (i + ii) * packn; + int jj = 0; - for (int jj = 0; jj + (packn - 1) < max_jj; jj += packn) + for (; jj + 7 < max_jj; jj += 8) { - for (int l = 0; l < packn; l++) - { - p0[l] = (__fp16)pp[l]; - } - pp += packn; - p0 += out_hstep * packn; + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + __riscv_vse16_v_f16m2(p0, __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pp, vl8w), vl8), vl8); + pp += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + __riscv_vse16_v_f16m1(p0, __riscv_vfncvt_f_f_w_f16m1(__riscv_vle32_v_f32m2(pp, vl4w), vl4), vl4); + pp += 4; + } + for (; jj < max_jj; jj += 1) + { + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + p0[0] = (__fp16)(pp[0]); + pp += 1; } } -#endif // __riscv_vector if (out_elempack == 1) +#endif // __riscv_vector { __fp16* p0 = (__fp16*)top_blob + j * out_hstep + (i + ii); - for (int jj = 0; jj < max_jj; jj += 1) + int jj = 0; +#if __riscv_vector + for (; jj + 7 < max_jj; jj += 8) + { + __riscv_vsse16_v_f16m2(p0, out_hstep * sizeof(__fp16), __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pp, vl8w), vl8), 
vl8); + pp += 8; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + __riscv_vsse16_v_f16m1(p0, out_hstep * sizeof(__fp16), __riscv_vfncvt_f_f_w_f16m1(__riscv_vle32_v_f32m2(pp, vl4w), vl4), vl4); + pp += 4; + p0 += out_hstep * 4; + } +#endif // __riscv_vector + for (; jj < max_jj; jj += 1) { p0[0] = (__fp16)(pp[0]); pp += 1; @@ -423,9 +575,15 @@ static void gemm_transB_packed_tile_fp16s(const Mat& AT_tile, const Mat& BT_tile const int packn = csrr_vlenb() / 2; const size_t vl = __riscv_vsetvl_e16m1(packn); const size_t vl2 = __riscv_vsetvl_e32m2(packn); + const size_t vl8 = __riscv_vsetvl_e16m2(8); + const size_t vl8w = __riscv_vsetvl_e32m4(8); + const size_t vl4 = __riscv_vsetvl_e16m1(4); + const size_t vl4w = __riscv_vsetvl_e32m2(4); #endif +#if __riscv_vector const int out_elempack = top_blob.elempack; +#endif const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; const __fp16* pAT = AT_tile; @@ -456,28 +614,69 @@ static void gemm_transB_packed_tile_fp16s(const Mat& AT_tile, const Mat& BT_tile } int jj = 0; - for (; jj + (packn - 1) < max_jj; jj += packn) + for (; jj + 7 < max_jj; jj += 8) { - if (packn == 16) - { - vfloat32m2_t _sum0; - vfloat32m2_t _sum1; - vfloat32m2_t _sum2; - vfloat32m2_t _sum3; - vfloat32m2_t _sum4; - vfloat32m2_t _sum5; - vfloat32m2_t _sum6; - vfloat32m2_t _sum7; - vfloat32m2_t _sum8; - vfloat32m2_t _sum9; - vfloat32m2_t _suma; - vfloat32m2_t _sumb; - vfloat32m2_t _sumc; - vfloat32m2_t _sumd; - vfloat32m2_t _sume; - vfloat32m2_t _sumf; - - if (k == 0) + vfloat32m2_t _sum0; + vfloat32m2_t _sum1; + vfloat32m2_t _sum2; + vfloat32m2_t _sum3; + vfloat32m2_t _sum4; + vfloat32m2_t _sum5; + vfloat32m2_t _sum6; + vfloat32m2_t _sum7; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = __riscv_vfmv_v_f_f32m2(pC[0], vl2); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + if (broadcast_type_C == 1 || 
broadcast_type_C == 2) + { + _sum0 = __riscv_vle32_v_f32m2(pC, vl2); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = __riscv_vle32_v_f32m2(pC, vl2); + _sum1 = __riscv_vle32_v_f32m2(pC + packn, vl2); + _sum2 = __riscv_vle32_v_f32m2(pC + packn * 2, vl2); + _sum3 = __riscv_vle32_v_f32m2(pC + packn * 3, vl2); + _sum4 = __riscv_vle32_v_f32m2(pC + packn * 4, vl2); + _sum5 = __riscv_vle32_v_f32m2(pC + packn * 5, vl2); + _sum6 = __riscv_vle32_v_f32m2(pC + packn * 6, vl2); + _sum7 = __riscv_vle32_v_f32m2(pC + packn * 7, vl2); + pC += packn * 8; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vfmv_v_f_f32m2(pC[0], vl2); + _sum1 = __riscv_vfmv_v_f_f32m2(pC[1], vl2); + _sum2 = __riscv_vfmv_v_f_f32m2(pC[2], vl2); + _sum3 = __riscv_vfmv_v_f_f32m2(pC[3], vl2); + _sum4 = __riscv_vfmv_v_f_f32m2(pC[4], vl2); + _sum5 = __riscv_vfmv_v_f_f32m2(pC[5], vl2); + _sum6 = __riscv_vfmv_v_f_f32m2(pC[6], vl2); + _sum7 = __riscv_vfmv_v_f_f32m2(pC[7], vl2); + pC += 8; + } + } + else { _sum0 = __riscv_vfmv_v_f_f32m2(0.f, vl2); _sum1 = __riscv_vfmv_v_f_f32m2(0.f, vl2); @@ -487,391 +686,201 @@ static void gemm_transB_packed_tile_fp16s(const Mat& AT_tile, const Mat& BT_tile _sum5 = __riscv_vfmv_v_f_f32m2(0.f, vl2); _sum6 = __riscv_vfmv_v_f_f32m2(0.f, vl2); _sum7 = __riscv_vfmv_v_f_f32m2(0.f, vl2); - _sum8 = __riscv_vfmv_v_f_f32m2(0.f, vl2); - _sum9 = __riscv_vfmv_v_f_f32m2(0.f, vl2); - _suma = __riscv_vfmv_v_f_f32m2(0.f, vl2); - _sumb = __riscv_vfmv_v_f_f32m2(0.f, vl2); - _sumc = __riscv_vfmv_v_f_f32m2(0.f, vl2); - _sumd = __riscv_vfmv_v_f_f32m2(0.f, vl2); - _sume = __riscv_vfmv_v_f_f32m2(0.f, vl2); - _sumf = __riscv_vfmv_v_f_f32m2(0.f, vl2); - - if (pC) - { - if (broadcast_type_C == 0) - { - _sum0 = __riscv_vfmv_v_f_f32m2(pC[0], vl2); - _sum1 = _sum0; - _sum2 = _sum0; - _sum3 = _sum0; - _sum4 = _sum0; - _sum5 = _sum0; - _sum6 = _sum0; - _sum7 = _sum0; - _sum8 = _sum0; - _sum9 
= _sum0; - _suma = _sum0; - _sumb = _sum0; - _sumc = _sum0; - _sumd = _sum0; - _sume = _sum0; - _sumf = _sum0; - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum0 = __riscv_vle32_v_f32m2(pC, vl2); - _sum1 = _sum0; - _sum2 = _sum0; - _sum3 = _sum0; - _sum4 = _sum0; - _sum5 = _sum0; - _sum6 = _sum0; - _sum7 = _sum0; - _sum8 = _sum0; - _sum9 = _sum0; - _suma = _sum0; - _sumb = _sum0; - _sumc = _sum0; - _sumd = _sum0; - _sume = _sum0; - _sumf = _sum0; - } - if (broadcast_type_C == 3) - { - _sum0 = __riscv_vle32_v_f32m2(pC, vl2); - _sum1 = __riscv_vle32_v_f32m2(pC + packn, vl2); - _sum2 = __riscv_vle32_v_f32m2(pC + packn * 2, vl2); - _sum3 = __riscv_vle32_v_f32m2(pC + packn * 3, vl2); - _sum4 = __riscv_vle32_v_f32m2(pC + packn * 4, vl2); - _sum5 = __riscv_vle32_v_f32m2(pC + packn * 5, vl2); - _sum6 = __riscv_vle32_v_f32m2(pC + packn * 6, vl2); - _sum7 = __riscv_vle32_v_f32m2(pC + packn * 7, vl2); - _sum8 = __riscv_vle32_v_f32m2(pC + packn * 8, vl2); - _sum9 = __riscv_vle32_v_f32m2(pC + packn * 9, vl2); - _suma = __riscv_vle32_v_f32m2(pC + packn * 10, vl2); - _sumb = __riscv_vle32_v_f32m2(pC + packn * 11, vl2); - _sumc = __riscv_vle32_v_f32m2(pC + packn * 12, vl2); - _sumd = __riscv_vle32_v_f32m2(pC + packn * 13, vl2); - _sume = __riscv_vle32_v_f32m2(pC + packn * 14, vl2); - _sumf = __riscv_vle32_v_f32m2(pC + packn * 15, vl2); - pC += packn * 16; - } - if (broadcast_type_C == 4) - { - _sum0 = __riscv_vfmv_v_f_f32m2(pC[0], vl2); - _sum1 = __riscv_vfmv_v_f_f32m2(pC[1], vl2); - _sum2 = __riscv_vfmv_v_f_f32m2(pC[2], vl2); - _sum3 = __riscv_vfmv_v_f_f32m2(pC[3], vl2); - _sum4 = __riscv_vfmv_v_f_f32m2(pC[4], vl2); - _sum5 = __riscv_vfmv_v_f_f32m2(pC[5], vl2); - _sum6 = __riscv_vfmv_v_f_f32m2(pC[6], vl2); - _sum7 = __riscv_vfmv_v_f_f32m2(pC[7], vl2); - _sum8 = __riscv_vfmv_v_f_f32m2(pC[8], vl2); - _sum9 = __riscv_vfmv_v_f_f32m2(pC[9], vl2); - _suma = __riscv_vfmv_v_f_f32m2(pC[10], vl2); - _sumb = __riscv_vfmv_v_f_f32m2(pC[11], vl2); - _sumc = 
__riscv_vfmv_v_f_f32m2(pC[12], vl2); - _sumd = __riscv_vfmv_v_f_f32m2(pC[13], vl2); - _sume = __riscv_vfmv_v_f_f32m2(pC[14], vl2); - _sumf = __riscv_vfmv_v_f_f32m2(pC[15], vl2); - pC += 16; - } - } - } - else - { - _sum0 = __riscv_vle32_v_f32m2(outptr, vl2); - _sum1 = __riscv_vle32_v_f32m2(outptr + packn, vl2); - _sum2 = __riscv_vle32_v_f32m2(outptr + packn * 2, vl2); - _sum3 = __riscv_vle32_v_f32m2(outptr + packn * 3, vl2); - _sum4 = __riscv_vle32_v_f32m2(outptr + packn * 4, vl2); - _sum5 = __riscv_vle32_v_f32m2(outptr + packn * 5, vl2); - _sum6 = __riscv_vle32_v_f32m2(outptr + packn * 6, vl2); - _sum7 = __riscv_vle32_v_f32m2(outptr + packn * 7, vl2); - _sum8 = __riscv_vle32_v_f32m2(outptr + packn * 8, vl2); - _sum9 = __riscv_vle32_v_f32m2(outptr + packn * 9, vl2); - _suma = __riscv_vle32_v_f32m2(outptr + packn * 10, vl2); - _sumb = __riscv_vle32_v_f32m2(outptr + packn * 11, vl2); - _sumc = __riscv_vle32_v_f32m2(outptr + packn * 12, vl2); - _sumd = __riscv_vle32_v_f32m2(outptr + packn * 13, vl2); - _sume = __riscv_vle32_v_f32m2(outptr + packn * 14, vl2); - _sumf = __riscv_vle32_v_f32m2(outptr + packn * 15, vl2); } + } + else + { + _sum0 = __riscv_vle32_v_f32m2(outptr, vl2); + _sum1 = __riscv_vle32_v_f32m2(outptr + packn, vl2); + _sum2 = __riscv_vle32_v_f32m2(outptr + packn * 2, vl2); + _sum3 = __riscv_vle32_v_f32m2(outptr + packn * 3, vl2); + _sum4 = __riscv_vle32_v_f32m2(outptr + packn * 4, vl2); + _sum5 = __riscv_vle32_v_f32m2(outptr + packn * 5, vl2); + _sum6 = __riscv_vle32_v_f32m2(outptr + packn * 6, vl2); + _sum7 = __riscv_vle32_v_f32m2(outptr + packn * 7, vl2); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pA = __riscv_vle16_v_f16m1(pA, vl); + _sum0 = __riscv_vfwmacc_vf_f32m2(_sum0, pB[0], _pA, vl); + _sum1 = __riscv_vfwmacc_vf_f32m2(_sum1, pB[1], _pA, vl); + _sum2 = __riscv_vfwmacc_vf_f32m2(_sum2, pB[2], _pA, vl); + _sum3 = __riscv_vfwmacc_vf_f32m2(_sum3, pB[3], _pA, vl); + _sum4 = 
__riscv_vfwmacc_vf_f32m2(_sum4, pB[4], _pA, vl); + _sum5 = __riscv_vfwmacc_vf_f32m2(_sum5, pB[5], _pA, vl); + _sum6 = __riscv_vfwmacc_vf_f32m2(_sum6, pB[6], _pA, vl); + _sum7 = __riscv_vfwmacc_vf_f32m2(_sum7, pB[7], _pA, vl); + pA += packn; + pB += 8; + } + + if (alpha != 1.f) + { + _sum0 = __riscv_vfmul_vf_f32m2(_sum0, alpha, vl2); + _sum1 = __riscv_vfmul_vf_f32m2(_sum1, alpha, vl2); + _sum2 = __riscv_vfmul_vf_f32m2(_sum2, alpha, vl2); + _sum3 = __riscv_vfmul_vf_f32m2(_sum3, alpha, vl2); + _sum4 = __riscv_vfmul_vf_f32m2(_sum4, alpha, vl2); + _sum5 = __riscv_vfmul_vf_f32m2(_sum5, alpha, vl2); + _sum6 = __riscv_vfmul_vf_f32m2(_sum6, alpha, vl2); + _sum7 = __riscv_vfmul_vf_f32m2(_sum7, alpha, vl2); + } - const __fp16* pA = pAT; - int kk = 0; - for (; kk < max_kk; kk += 1) + if (k_end) + { + if (out_elempack == packn) { - vfloat16m1_t _pA = __riscv_vle16_v_f16m1(pA, vl); - _sum0 = __riscv_vfwmacc_vf_f32m2(_sum0, pB[0], _pA, vl); - _sum1 = __riscv_vfwmacc_vf_f32m2(_sum1, pB[1], _pA, vl); - _sum2 = __riscv_vfwmacc_vf_f32m2(_sum2, pB[2], _pA, vl); - _sum3 = __riscv_vfwmacc_vf_f32m2(_sum3, pB[3], _pA, vl); - _sum4 = __riscv_vfwmacc_vf_f32m2(_sum4, pB[4], _pA, vl); - _sum5 = __riscv_vfwmacc_vf_f32m2(_sum5, pB[5], _pA, vl); - _sum6 = __riscv_vfwmacc_vf_f32m2(_sum6, pB[6], _pA, vl); - _sum7 = __riscv_vfwmacc_vf_f32m2(_sum7, pB[7], _pA, vl); - _sum8 = __riscv_vfwmacc_vf_f32m2(_sum8, pB[8], _pA, vl); - _sum9 = __riscv_vfwmacc_vf_f32m2(_sum9, pB[9], _pA, vl); - _suma = __riscv_vfwmacc_vf_f32m2(_suma, pB[10], _pA, vl); - _sumb = __riscv_vfwmacc_vf_f32m2(_sumb, pB[11], _pA, vl); - _sumc = __riscv_vfwmacc_vf_f32m2(_sumc, pB[12], _pA, vl); - _sumd = __riscv_vfwmacc_vf_f32m2(_sumd, pB[13], _pA, vl); - _sume = __riscv_vfwmacc_vf_f32m2(_sume, pB[14], _pA, vl); - _sumf = __riscv_vfwmacc_vf_f32m2(_sumf, pB[15], _pA, vl); - pA += packn; - pB += 16; + __riscv_vse16_v_f16m1(outptr0, __riscv_vfncvt_f_f_w_f16m1(_sum0, vl), vl); + __riscv_vse16_v_f16m1(outptr0 + packn, 
__riscv_vfncvt_f_f_w_f16m1(_sum1, vl), vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 2, __riscv_vfncvt_f_f_w_f16m1(_sum2, vl), vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 3, __riscv_vfncvt_f_f_w_f16m1(_sum3, vl), vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 4, __riscv_vfncvt_f_f_w_f16m1(_sum4, vl), vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 5, __riscv_vfncvt_f_f_w_f16m1(_sum5, vl), vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 6, __riscv_vfncvt_f_f_w_f16m1(_sum6, vl), vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 7, __riscv_vfncvt_f_f_w_f16m1(_sum7, vl), vl); + outptr0 += packn * 8; } - - if (alpha != 1.f) + if (out_elempack == 1) { - _sum0 = __riscv_vfmul_vf_f32m2(_sum0, alpha, vl2); - _sum1 = __riscv_vfmul_vf_f32m2(_sum1, alpha, vl2); - _sum2 = __riscv_vfmul_vf_f32m2(_sum2, alpha, vl2); - _sum3 = __riscv_vfmul_vf_f32m2(_sum3, alpha, vl2); - _sum4 = __riscv_vfmul_vf_f32m2(_sum4, alpha, vl2); - _sum5 = __riscv_vfmul_vf_f32m2(_sum5, alpha, vl2); - _sum6 = __riscv_vfmul_vf_f32m2(_sum6, alpha, vl2); - _sum7 = __riscv_vfmul_vf_f32m2(_sum7, alpha, vl2); - _sum8 = __riscv_vfmul_vf_f32m2(_sum8, alpha, vl2); - _sum9 = __riscv_vfmul_vf_f32m2(_sum9, alpha, vl2); - _suma = __riscv_vfmul_vf_f32m2(_suma, alpha, vl2); - _sumb = __riscv_vfmul_vf_f32m2(_sumb, alpha, vl2); - _sumc = __riscv_vfmul_vf_f32m2(_sumc, alpha, vl2); - _sumd = __riscv_vfmul_vf_f32m2(_sumd, alpha, vl2); - _sume = __riscv_vfmul_vf_f32m2(_sume, alpha, vl2); - _sumf = __riscv_vfmul_vf_f32m2(_sumf, alpha, vl2); + vfloat16m1x8_t _sum_f16 = __riscv_vcreate_v_f16m1x8( + __riscv_vfncvt_f_f_w_f16m1(_sum0, vl), + __riscv_vfncvt_f_f_w_f16m1(_sum1, vl), + __riscv_vfncvt_f_f_w_f16m1(_sum2, vl), + __riscv_vfncvt_f_f_w_f16m1(_sum3, vl), + __riscv_vfncvt_f_f_w_f16m1(_sum4, vl), + __riscv_vfncvt_f_f_w_f16m1(_sum5, vl), + __riscv_vfncvt_f_f_w_f16m1(_sum6, vl), + __riscv_vfncvt_f_f_w_f16m1(_sum7, vl)); + __riscv_vssseg8e16_v_f16m1x8(outptr0, out_hstep * sizeof(__fp16), _sum_f16, vl); + outptr0 += 8; } + } + else + 
{ + __riscv_vse32_v_f32m2(outptr, _sum0, vl2); + __riscv_vse32_v_f32m2(outptr + packn, _sum1, vl2); + __riscv_vse32_v_f32m2(outptr + packn * 2, _sum2, vl2); + __riscv_vse32_v_f32m2(outptr + packn * 3, _sum3, vl2); + __riscv_vse32_v_f32m2(outptr + packn * 4, _sum4, vl2); + __riscv_vse32_v_f32m2(outptr + packn * 5, _sum5, vl2); + __riscv_vse32_v_f32m2(outptr + packn * 6, _sum6, vl2); + __riscv_vse32_v_f32m2(outptr + packn * 7, _sum7, vl2); + } + + outptr += packn * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat32m2_t _sum0; + vfloat32m2_t _sum1; + vfloat32m2_t _sum2; + vfloat32m2_t _sum3; - if (k_end) + if (k == 0) + { + if (pC) { - if (out_elempack == packn) + if (broadcast_type_C == 0) { - __riscv_vse16_v_f16m1(outptr0, __riscv_vfncvt_f_f_w_f16m1(_sum0, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn, __riscv_vfncvt_f_f_w_f16m1(_sum1, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 2, __riscv_vfncvt_f_f_w_f16m1(_sum2, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 3, __riscv_vfncvt_f_f_w_f16m1(_sum3, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 4, __riscv_vfncvt_f_f_w_f16m1(_sum4, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 5, __riscv_vfncvt_f_f_w_f16m1(_sum5, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 6, __riscv_vfncvt_f_f_w_f16m1(_sum6, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 7, __riscv_vfncvt_f_f_w_f16m1(_sum7, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 8, __riscv_vfncvt_f_f_w_f16m1(_sum8, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 9, __riscv_vfncvt_f_f_w_f16m1(_sum9, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 10, __riscv_vfncvt_f_f_w_f16m1(_suma, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 11, __riscv_vfncvt_f_f_w_f16m1(_sumb, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 12, __riscv_vfncvt_f_f_w_f16m1(_sumc, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 13, __riscv_vfncvt_f_f_w_f16m1(_sumd, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 14, 
__riscv_vfncvt_f_f_w_f16m1(_sume, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 15, __riscv_vfncvt_f_f_w_f16m1(_sumf, vl), vl); - outptr0 += packn * 16; + _sum0 = __riscv_vfmv_v_f_f32m2(pC[0], vl2); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; } - if (out_elempack == 1) + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = __riscv_vle32_v_f32m2(pC, vl2); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = __riscv_vle32_v_f32m2(pC, vl2); + _sum1 = __riscv_vle32_v_f32m2(pC + packn, vl2); + _sum2 = __riscv_vle32_v_f32m2(pC + packn * 2, vl2); + _sum3 = __riscv_vle32_v_f32m2(pC + packn * 3, vl2); + pC += packn * 4; + } + if (broadcast_type_C == 4) { - vfloat16m1x8_t _sum0_f16 = __riscv_vcreate_v_f16m1x8( - __riscv_vfncvt_f_f_w_f16m1(_sum0, vl), - __riscv_vfncvt_f_f_w_f16m1(_sum1, vl), - __riscv_vfncvt_f_f_w_f16m1(_sum2, vl), - __riscv_vfncvt_f_f_w_f16m1(_sum3, vl), - __riscv_vfncvt_f_f_w_f16m1(_sum4, vl), - __riscv_vfncvt_f_f_w_f16m1(_sum5, vl), - __riscv_vfncvt_f_f_w_f16m1(_sum6, vl), - __riscv_vfncvt_f_f_w_f16m1(_sum7, vl)); - vfloat16m1x8_t _sum1_f16 = __riscv_vcreate_v_f16m1x8( - __riscv_vfncvt_f_f_w_f16m1(_sum8, vl), - __riscv_vfncvt_f_f_w_f16m1(_sum9, vl), - __riscv_vfncvt_f_f_w_f16m1(_suma, vl), - __riscv_vfncvt_f_f_w_f16m1(_sumb, vl), - __riscv_vfncvt_f_f_w_f16m1(_sumc, vl), - __riscv_vfncvt_f_f_w_f16m1(_sumd, vl), - __riscv_vfncvt_f_f_w_f16m1(_sume, vl), - __riscv_vfncvt_f_f_w_f16m1(_sumf, vl)); - __riscv_vssseg8e16_v_f16m1x8(outptr0, out_hstep * sizeof(__fp16), _sum0_f16, vl); - __riscv_vssseg8e16_v_f16m1x8(outptr0 + 8, out_hstep * sizeof(__fp16), _sum1_f16, vl); - outptr0 += 16; + _sum0 = __riscv_vfmv_v_f_f32m2(pC[0], vl2); + _sum1 = __riscv_vfmv_v_f_f32m2(pC[1], vl2); + _sum2 = __riscv_vfmv_v_f_f32m2(pC[2], vl2); + _sum3 = __riscv_vfmv_v_f_f32m2(pC[3], vl2); + pC += 4; } } else - { - __riscv_vse32_v_f32m2(outptr, _sum0, vl2); - __riscv_vse32_v_f32m2(outptr + packn, _sum1, vl2); - 
__riscv_vse32_v_f32m2(outptr + packn * 2, _sum2, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 3, _sum3, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 4, _sum4, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 5, _sum5, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 6, _sum6, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 7, _sum7, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 8, _sum8, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 9, _sum9, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 10, _suma, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 11, _sumb, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 12, _sumc, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 13, _sumd, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 14, _sume, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 15, _sumf, vl2); - } - - outptr += packn * 16; - } - else if (packn == 8) - { - vfloat32m2_t _sum0; - vfloat32m2_t _sum1; - vfloat32m2_t _sum2; - vfloat32m2_t _sum3; - vfloat32m2_t _sum4; - vfloat32m2_t _sum5; - vfloat32m2_t _sum6; - vfloat32m2_t _sum7; - - if (k == 0) { _sum0 = __riscv_vfmv_v_f_f32m2(0.f, vl2); _sum1 = __riscv_vfmv_v_f_f32m2(0.f, vl2); _sum2 = __riscv_vfmv_v_f_f32m2(0.f, vl2); _sum3 = __riscv_vfmv_v_f_f32m2(0.f, vl2); - _sum4 = __riscv_vfmv_v_f_f32m2(0.f, vl2); - _sum5 = __riscv_vfmv_v_f_f32m2(0.f, vl2); - _sum6 = __riscv_vfmv_v_f_f32m2(0.f, vl2); - _sum7 = __riscv_vfmv_v_f_f32m2(0.f, vl2); - - if (pC) - { - if (broadcast_type_C == 0) - { - _sum0 = __riscv_vfmv_v_f_f32m2(pC[0], vl2); - _sum1 = _sum0; - _sum2 = _sum0; - _sum3 = _sum0; - _sum4 = _sum0; - _sum5 = _sum0; - _sum6 = _sum0; - _sum7 = _sum0; - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum0 = __riscv_vle32_v_f32m2(pC, vl2); - _sum1 = _sum0; - _sum2 = _sum0; - _sum3 = _sum0; - _sum4 = _sum0; - _sum5 = _sum0; - _sum6 = _sum0; - _sum7 = _sum0; - } - if (broadcast_type_C == 3) - { - _sum0 = __riscv_vle32_v_f32m2(pC, vl2); - _sum1 = __riscv_vle32_v_f32m2(pC + packn, vl2); - _sum2 = 
__riscv_vle32_v_f32m2(pC + packn * 2, vl2); - _sum3 = __riscv_vle32_v_f32m2(pC + packn * 3, vl2); - _sum4 = __riscv_vle32_v_f32m2(pC + packn * 4, vl2); - _sum5 = __riscv_vle32_v_f32m2(pC + packn * 5, vl2); - _sum6 = __riscv_vle32_v_f32m2(pC + packn * 6, vl2); - _sum7 = __riscv_vle32_v_f32m2(pC + packn * 7, vl2); - pC += packn * 8; - } - if (broadcast_type_C == 4) - { - _sum0 = __riscv_vfmv_v_f_f32m2(pC[0], vl2); - _sum1 = __riscv_vfmv_v_f_f32m2(pC[1], vl2); - _sum2 = __riscv_vfmv_v_f_f32m2(pC[2], vl2); - _sum3 = __riscv_vfmv_v_f_f32m2(pC[3], vl2); - _sum4 = __riscv_vfmv_v_f_f32m2(pC[4], vl2); - _sum5 = __riscv_vfmv_v_f_f32m2(pC[5], vl2); - _sum6 = __riscv_vfmv_v_f_f32m2(pC[6], vl2); - _sum7 = __riscv_vfmv_v_f_f32m2(pC[7], vl2); - pC += 8; - } - } - } - else - { - _sum0 = __riscv_vle32_v_f32m2(outptr, vl2); - _sum1 = __riscv_vle32_v_f32m2(outptr + packn, vl2); - _sum2 = __riscv_vle32_v_f32m2(outptr + packn * 2, vl2); - _sum3 = __riscv_vle32_v_f32m2(outptr + packn * 3, vl2); - _sum4 = __riscv_vle32_v_f32m2(outptr + packn * 4, vl2); - _sum5 = __riscv_vle32_v_f32m2(outptr + packn * 5, vl2); - _sum6 = __riscv_vle32_v_f32m2(outptr + packn * 6, vl2); - _sum7 = __riscv_vle32_v_f32m2(outptr + packn * 7, vl2); } + } + else + { + _sum0 = __riscv_vle32_v_f32m2(outptr, vl2); + _sum1 = __riscv_vle32_v_f32m2(outptr + packn, vl2); + _sum2 = __riscv_vle32_v_f32m2(outptr + packn * 2, vl2); + _sum3 = __riscv_vle32_v_f32m2(outptr + packn * 3, vl2); + } - const __fp16* pA = pAT; - int kk = 0; - for (; kk < max_kk; kk += 1) - { - vfloat16m1_t _pA = __riscv_vle16_v_f16m1(pA, vl); - _sum0 = __riscv_vfwmacc_vf_f32m2(_sum0, pB[0], _pA, vl); - _sum1 = __riscv_vfwmacc_vf_f32m2(_sum1, pB[1], _pA, vl); - _sum2 = __riscv_vfwmacc_vf_f32m2(_sum2, pB[2], _pA, vl); - _sum3 = __riscv_vfwmacc_vf_f32m2(_sum3, pB[3], _pA, vl); - _sum4 = __riscv_vfwmacc_vf_f32m2(_sum4, pB[4], _pA, vl); - _sum5 = __riscv_vfwmacc_vf_f32m2(_sum5, pB[5], _pA, vl); - _sum6 = __riscv_vfwmacc_vf_f32m2(_sum6, pB[6], _pA, vl); - 
_sum7 = __riscv_vfwmacc_vf_f32m2(_sum7, pB[7], _pA, vl); - pA += packn; - pB += 8; - } + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pA = __riscv_vle16_v_f16m1(pA, vl); + _sum0 = __riscv_vfwmacc_vf_f32m2(_sum0, pB[0], _pA, vl); + _sum1 = __riscv_vfwmacc_vf_f32m2(_sum1, pB[1], _pA, vl); + _sum2 = __riscv_vfwmacc_vf_f32m2(_sum2, pB[2], _pA, vl); + _sum3 = __riscv_vfwmacc_vf_f32m2(_sum3, pB[3], _pA, vl); + pA += packn; + pB += 4; + } - if (alpha != 1.f) - { - _sum0 = __riscv_vfmul_vf_f32m2(_sum0, alpha, vl2); - _sum1 = __riscv_vfmul_vf_f32m2(_sum1, alpha, vl2); - _sum2 = __riscv_vfmul_vf_f32m2(_sum2, alpha, vl2); - _sum3 = __riscv_vfmul_vf_f32m2(_sum3, alpha, vl2); - _sum4 = __riscv_vfmul_vf_f32m2(_sum4, alpha, vl2); - _sum5 = __riscv_vfmul_vf_f32m2(_sum5, alpha, vl2); - _sum6 = __riscv_vfmul_vf_f32m2(_sum6, alpha, vl2); - _sum7 = __riscv_vfmul_vf_f32m2(_sum7, alpha, vl2); - } + if (alpha != 1.f) + { + _sum0 = __riscv_vfmul_vf_f32m2(_sum0, alpha, vl2); + _sum1 = __riscv_vfmul_vf_f32m2(_sum1, alpha, vl2); + _sum2 = __riscv_vfmul_vf_f32m2(_sum2, alpha, vl2); + _sum3 = __riscv_vfmul_vf_f32m2(_sum3, alpha, vl2); + } - if (k_end) + if (k_end) + { + if (out_elempack == packn) { - if (out_elempack == packn) - { - __riscv_vse16_v_f16m1(outptr0, __riscv_vfncvt_f_f_w_f16m1(_sum0, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn, __riscv_vfncvt_f_f_w_f16m1(_sum1, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 2, __riscv_vfncvt_f_f_w_f16m1(_sum2, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 3, __riscv_vfncvt_f_f_w_f16m1(_sum3, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 4, __riscv_vfncvt_f_f_w_f16m1(_sum4, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 5, __riscv_vfncvt_f_f_w_f16m1(_sum5, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 6, __riscv_vfncvt_f_f_w_f16m1(_sum6, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + packn * 7, __riscv_vfncvt_f_f_w_f16m1(_sum7, vl), vl); - outptr0 += packn * 8; - } - if 
(out_elempack == 1) - { - vfloat16m1x8_t _sum_f16 = __riscv_vcreate_v_f16m1x8( - __riscv_vfncvt_f_f_w_f16m1(_sum0, vl), - __riscv_vfncvt_f_f_w_f16m1(_sum1, vl), - __riscv_vfncvt_f_f_w_f16m1(_sum2, vl), - __riscv_vfncvt_f_f_w_f16m1(_sum3, vl), - __riscv_vfncvt_f_f_w_f16m1(_sum4, vl), - __riscv_vfncvt_f_f_w_f16m1(_sum5, vl), - __riscv_vfncvt_f_f_w_f16m1(_sum6, vl), - __riscv_vfncvt_f_f_w_f16m1(_sum7, vl)); - __riscv_vssseg8e16_v_f16m1x8(outptr0, out_hstep * sizeof(__fp16), _sum_f16, vl); - outptr0 += 8; - } + __riscv_vse16_v_f16m1(outptr0, __riscv_vfncvt_f_f_w_f16m1(_sum0, vl), vl); + __riscv_vse16_v_f16m1(outptr0 + packn, __riscv_vfncvt_f_f_w_f16m1(_sum1, vl), vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 2, __riscv_vfncvt_f_f_w_f16m1(_sum2, vl), vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 3, __riscv_vfncvt_f_f_w_f16m1(_sum3, vl), vl); + outptr0 += packn * 4; } - else + if (out_elempack == 1) { - __riscv_vse32_v_f32m2(outptr, _sum0, vl2); - __riscv_vse32_v_f32m2(outptr + packn, _sum1, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 2, _sum2, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 3, _sum3, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 4, _sum4, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 5, _sum5, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 6, _sum6, vl2); - __riscv_vse32_v_f32m2(outptr + packn * 7, _sum7, vl2); + vfloat16m1x4_t _sum_f16 = __riscv_vcreate_v_f16m1x4( + __riscv_vfncvt_f_f_w_f16m1(_sum0, vl), + __riscv_vfncvt_f_f_w_f16m1(_sum1, vl), + __riscv_vfncvt_f_f_w_f16m1(_sum2, vl), + __riscv_vfncvt_f_f_w_f16m1(_sum3, vl)); + __riscv_vssseg4e16_v_f16m1x4(outptr0, out_hstep * sizeof(__fp16), _sum_f16, vl); + outptr0 += 4; } - - outptr += packn * 8; } else { - NCNN_LOGE("unsupported vector length"); + __riscv_vse32_v_f32m2(outptr, _sum0, vl2); + __riscv_vse32_v_f32m2(outptr + packn, _sum1, vl2); + __riscv_vse32_v_f32m2(outptr + packn * 2, _sum2, vl2); + __riscv_vse32_v_f32m2(outptr + packn * 3, _sum3, vl2); } + + outptr += packn * 4; } for (; jj 
+ 1 < max_jj; jj += 2) { @@ -880,9 +889,6 @@ static void gemm_transB_packed_tile_fp16s(const Mat& AT_tile, const Mat& BT_tile if (k == 0) { - _sum0 = __riscv_vfmv_v_f_f32m2(0.f, vl2); - _sum1 = __riscv_vfmv_v_f_f32m2(0.f, vl2); - if (pC) { if (broadcast_type_C == 0) @@ -908,6 +914,11 @@ static void gemm_transB_packed_tile_fp16s(const Mat& AT_tile, const Mat& BT_tile pC += 2; } } + else + { + _sum0 = __riscv_vfmv_v_f_f32m2(0.f, vl2); + _sum1 = __riscv_vfmv_v_f_f32m2(0.f, vl2); + } } else { @@ -963,8 +974,6 @@ static void gemm_transB_packed_tile_fp16s(const Mat& AT_tile, const Mat& BT_tile if (k == 0) { - _sum0 = __riscv_vfmv_v_f_f32m2(0.f, vl2); - if (pC) { if (broadcast_type_C == 0) @@ -986,6 +995,10 @@ static void gemm_transB_packed_tile_fp16s(const Mat& AT_tile, const Mat& BT_tile pC += 1; } } + else + { + _sum0 = __riscv_vfmv_v_f_f32m2(0.f, vl2); + } } else { @@ -1051,82 +1064,163 @@ static void gemm_transB_packed_tile_fp16s(const Mat& AT_tile, const Mat& BT_tile int jj = 0; #if __riscv_vector - for (; jj + (packn - 1) < max_jj; jj += packn) + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m4_t _sum0; + vfloat32m4_t _sum1; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = __riscv_vfmv_v_f_f32m4(pC[0], vl8w); + _sum1 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = __riscv_vfmv_v_f_f32m4(pC[0], vl8w); + _sum1 = __riscv_vfmv_v_f_f32m4(pC[1], vl8w); + } + if (broadcast_type_C == 3) + { + vfloat32m4x2_t _s0 = __riscv_vlseg2e32_v_f32m4x2(pC, vl8w); + _sum0 = __riscv_vget_v_f32m4x2_f32m4(_s0, 0); + _sum1 = __riscv_vget_v_f32m4x2_f32m4(_s0, 1); + pC += 8 * 2; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vle32_v_f32m4(pC, vl8w); + _sum1 = _sum0; + pC += 8; + } + } + else + { + _sum0 = __riscv_vfmv_v_f_f32m4(0.f, vl8w); + _sum1 = __riscv_vfmv_v_f_f32m4(0.f, vl8w); + } + } + else + { + _sum0 = __riscv_vle32_v_f32m4(outptr, vl8w); + _sum1 = __riscv_vle32_v_f32m4(outptr + 8, vl8w); + } + + const __fp16* pA 
= pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m2_t _pB = __riscv_vle16_v_f16m2(pB, vl8); + _sum0 = __riscv_vfwmacc_vf_f32m4(_sum0, pA[0], _pB, vl8); + _sum1 = __riscv_vfwmacc_vf_f32m4(_sum1, pA[1], _pB, vl8); + pA += 2; + pB += 8; + } + + if (alpha != 1.f) + { + _sum0 = __riscv_vfmul_vf_f32m4(_sum0, alpha, vl8w); + _sum1 = __riscv_vfmul_vf_f32m4(_sum1, alpha, vl8w); + } + + if (k_end) + { + // if (out_elempack == 1) + { + __riscv_vse16_v_f16m2(outptr0, __riscv_vfncvt_f_f_w_f16m2(_sum0, vl8), vl8); + __riscv_vse16_v_f16m2(outptr0 + out_hstep, __riscv_vfncvt_f_f_w_f16m2(_sum1, vl8), vl8); + outptr0 += 8; + } + } + else + { + __riscv_vse32_v_f32m4(outptr, _sum0, vl8w); + __riscv_vse32_v_f32m4(outptr + 8, _sum1, vl8w); + } + + outptr += 8 * 2; + } + for (; jj + 3 < max_jj; jj += 4) { vfloat32m2_t _sum0; vfloat32m2_t _sum1; if (k == 0) { - _sum0 = __riscv_vfmv_v_f_f32m2(0.f, vl2); - _sum1 = __riscv_vfmv_v_f_f32m2(0.f, vl2); - if (pC) { if (broadcast_type_C == 0) { - _sum0 = __riscv_vfmv_v_f_f32m2(pC[0], vl2); + _sum0 = __riscv_vfmv_v_f_f32m2(pC[0], vl4w); _sum1 = _sum0; } if (broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum0 = __riscv_vfmv_v_f_f32m2(pC[0], vl2); - _sum1 = __riscv_vfmv_v_f_f32m2(pC[1], vl2); + _sum0 = __riscv_vfmv_v_f_f32m2(pC[0], vl4w); + _sum1 = __riscv_vfmv_v_f_f32m2(pC[1], vl4w); } if (broadcast_type_C == 3) { - vfloat32m2x2_t _s0 = __riscv_vlseg2e32_v_f32m2x2(pC, vl2); + vfloat32m2x2_t _s0 = __riscv_vlseg2e32_v_f32m2x2(pC, vl4w); _sum0 = __riscv_vget_v_f32m2x2_f32m2(_s0, 0); _sum1 = __riscv_vget_v_f32m2x2_f32m2(_s0, 1); - pC += packn * 2; + pC += 4 * 2; } if (broadcast_type_C == 4) { - _sum0 = __riscv_vle32_v_f32m2(pC, vl2); + _sum0 = __riscv_vle32_v_f32m2(pC, vl4w); _sum1 = _sum0; - pC += packn; + pC += 4; } } + else + { + _sum0 = __riscv_vfmv_v_f_f32m2(0.f, vl4w); + _sum1 = __riscv_vfmv_v_f_f32m2(0.f, vl4w); + } } else { - vfloat32m2x2_t _s0 = __riscv_vlseg2e32_v_f32m2x2(outptr, vl2); - _sum0 = 
__riscv_vget_v_f32m2x2_f32m2(_s0, 0); - _sum1 = __riscv_vget_v_f32m2x2_f32m2(_s0, 1); + _sum0 = __riscv_vle32_v_f32m2(outptr, vl4w); + _sum1 = __riscv_vle32_v_f32m2(outptr + 4, vl4w); } const __fp16* pA = pAT; int kk = 0; for (; kk < max_kk; kk += 1) { - vfloat16m1_t _pB = __riscv_vle16_v_f16m1(pB, vl); - _sum0 = __riscv_vfwmacc_vf_f32m2(_sum0, pA[0], _pB, vl); - _sum1 = __riscv_vfwmacc_vf_f32m2(_sum1, pA[1], _pB, vl); + vfloat16m1_t _pB = __riscv_vle16_v_f16m1(pB, vl4); + _sum0 = __riscv_vfwmacc_vf_f32m2(_sum0, pA[0], _pB, vl4); + _sum1 = __riscv_vfwmacc_vf_f32m2(_sum1, pA[1], _pB, vl4); pA += 2; - pB += packn; + pB += 4; } if (alpha != 1.f) { - _sum0 = __riscv_vfmul_vf_f32m2(_sum0, alpha, vl2); - _sum1 = __riscv_vfmul_vf_f32m2(_sum1, alpha, vl2); + _sum0 = __riscv_vfmul_vf_f32m2(_sum0, alpha, vl4w); + _sum1 = __riscv_vfmul_vf_f32m2(_sum1, alpha, vl4w); } if (k_end) { // if (out_elempack == 1) { - __riscv_vse16_v_f16m1(outptr0, __riscv_vfncvt_f_f_w_f16m1(_sum0, vl), vl); - __riscv_vse16_v_f16m1(outptr0 + out_hstep, __riscv_vfncvt_f_f_w_f16m1(_sum1, vl), vl); - outptr0 += packn; + __riscv_vse16_v_f16m1(outptr0, __riscv_vfncvt_f_f_w_f16m1(_sum0, vl4), vl4); + __riscv_vse16_v_f16m1(outptr0 + out_hstep, __riscv_vfncvt_f_f_w_f16m1(_sum1, vl4), vl4); + outptr0 += 4; } } else { - __riscv_vsseg2e32_v_f32m2x2(outptr, __riscv_vcreate_v_f32m2x2(_sum0, _sum1), vl2); + __riscv_vse32_v_f32m2(outptr, _sum0, vl4w); + __riscv_vse32_v_f32m2(outptr + 4, _sum1, vl4w); } - outptr += packn * 2; + outptr += 4 * 2; } #endif // __riscv_vector for (; jj + 1 < max_jj; jj += 2) @@ -1180,8 +1274,8 @@ static void gemm_transB_packed_tile_fp16s(const Mat& AT_tile, const Mat& BT_tile else { sum00 = outptr[0]; - sum01 = outptr[1]; - sum10 = outptr[2]; + sum10 = outptr[1]; + sum01 = outptr[2]; sum11 = outptr[3]; } @@ -1219,8 +1313,8 @@ static void gemm_transB_packed_tile_fp16s(const Mat& AT_tile, const Mat& BT_tile else { outptr[0] = sum00; - outptr[1] = sum01; - outptr[2] = sum10; + outptr[1] = 
sum10; + outptr[2] = sum01; outptr[3] = sum11; } @@ -1324,61 +1418,121 @@ static void gemm_transB_packed_tile_fp16s(const Mat& AT_tile, const Mat& BT_tile int jj = 0; #if __riscv_vector - for (; jj + (packn - 1) < max_jj; jj += packn) + for (; jj + 7 < max_jj; jj += 8) { - vfloat32m2_t _sum; + vfloat32m4_t _sum; if (k == 0) { - _sum = __riscv_vfmv_v_f_f32m2(0.f, vl2); + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum = __riscv_vfmv_v_f_f32m4(pC[0], vl8w); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _sum = __riscv_vle32_v_f32m4(pC, vl8w); + pC += 8; + } + } + else + { + _sum = __riscv_vfmv_v_f_f32m4(0.f, vl8w); + } + } + else + { + _sum = __riscv_vle32_v_f32m4(outptr, vl8w); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m2_t _pB = __riscv_vle16_v_f16m2(pB, vl8); + _sum = __riscv_vfwmacc_vf_f32m4(_sum, pA[0], _pB, vl8); + pA += 1; + pB += 8; + } + + if (alpha != 1.f) + { + _sum = __riscv_vfmul_vf_f32m4(_sum, alpha, vl8w); + } + if (k_end) + { + // if (out_elempack == 1) + { + __riscv_vse16_v_f16m2(outptr0, __riscv_vfncvt_f_f_w_f16m2(_sum, vl8), vl8); + outptr0 += 8; + } + } + else + { + __riscv_vse32_v_f32m4(outptr, _sum, vl8w); + } + + outptr += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat32m2_t _sum; + + if (k == 0) + { if (pC) { if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum = __riscv_vfmv_v_f_f32m2(pC[0], vl2); + _sum = __riscv_vfmv_v_f_f32m2(pC[0], vl4w); } if (broadcast_type_C == 3 || broadcast_type_C == 4) { - _sum = __riscv_vle32_v_f32m2(pC, vl2); - pC += packn; + _sum = __riscv_vle32_v_f32m2(pC, vl4w); + pC += 4; } } + else + { + _sum = __riscv_vfmv_v_f_f32m2(0.f, vl4w); + } } else { - _sum = __riscv_vle32_v_f32m2(outptr, vl2); + _sum = __riscv_vle32_v_f32m2(outptr, vl4w); } const __fp16* pA = pAT; int kk = 0; for (; kk < max_kk; kk += 1) { - vfloat16m1_t _pB = __riscv_vle16_v_f16m1(pB, vl); - 
_sum = __riscv_vfwmacc_vf_f32m2(_sum, pA[0], _pB, vl); + vfloat16m1_t _pB = __riscv_vle16_v_f16m1(pB, vl4); + _sum = __riscv_vfwmacc_vf_f32m2(_sum, pA[0], _pB, vl4); pA += 1; - pB += packn; + pB += 4; } if (alpha != 1.f) { - _sum = __riscv_vfmul_vf_f32m2(_sum, alpha, vl2); + _sum = __riscv_vfmul_vf_f32m2(_sum, alpha, vl4w); } if (k_end) { // if (out_elempack == 1) { - __riscv_vse16_v_f16m1(outptr0, __riscv_vfncvt_f_f_w_f16m1(_sum, vl), vl); - outptr0 += packn; + __riscv_vse16_v_f16m1(outptr0, __riscv_vfncvt_f_f_w_f16m1(_sum, vl4), vl4); + outptr0 += 4; } } else { - __riscv_vse32_v_f32m2(outptr, _sum, vl2); + __riscv_vse32_v_f32m2(outptr, _sum, vl4w); } - outptr += packn; + outptr += 4; } #endif // __riscv_vector for (; jj + 1 < max_jj; jj += 2) diff --git a/src/layer/riscv/gemm_fp16sa.h b/src/layer/riscv/gemm_fp16sa.h new file mode 100644 index 000000000000..798d0e0dd4d3 --- /dev/null +++ b/src/layer/riscv/gemm_fp16sa.h @@ -0,0 +1,2134 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +static void pack_B_tile_fp16sa(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ +#if __riscv_vector + const int packn = csrr_vlenb() / 2; + const size_t vl = __riscv_vsetvl_e16m1(packn); + const size_t vl16 = __riscv_vsetvl_e16m2(16); + const size_t vl8 = __riscv_vsetvl_e16m1(8); + const size_t vl4 = __riscv_vsetvl_e16m1(4); +#endif + + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? 
(int)B.cstep : B.w; + + unsigned short* pp = BT; + + int jj = 0; +#if __riscv_vector + for (; jj + 15 < max_jj; jj += 16) + { + if (elempack == packn) + { + const int q = (j + jj) / packn * packn; + const int r = (j + jj) % packn; + const unsigned short* p0 = (const unsigned short*)B + q * B_hstep + k * packn + r; + + if (packn >= 16) + { + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl16), vl16); + pp += 16; + p0 += packn; + } + } + if (packn == 8) + { + const unsigned short* p1 = (const unsigned short*)B + (q + 8) * B_hstep + k * 8; + + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl8), vl8); + __riscv_vse16_v_u16m1(pp + 8, __riscv_vle16_v_u16m1(p1, vl8), vl8); + pp += 16; + p0 += 8; + p1 += 8; + } + } + } + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse16_v_u16m2(pp, __riscv_vlse16_v_u16m2(p0, B_hstep * sizeof(unsigned short), vl16), vl16); + pp += 16; + p0++; + } + } + } + for (; jj + 7 < max_jj; jj += 8) + { + if (elempack == packn) + { + const int q = (j + jj) / packn * packn; + const int r = (j + jj) % packn; + const unsigned short* p0 = (const unsigned short*)B + q * B_hstep + k * packn + r; + + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl8), vl8); + pp += 8; + p0 += packn; + } + } + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0, B_hstep * sizeof(unsigned short), vl8), vl8); + pp += 8; + p0++; + } + } + } + for (; jj + 3 < max_jj; jj += 4) + { + if (elempack == packn) + { + const int q = (j + jj) / packn * packn; + const int r = (j + jj) % packn; + const unsigned short* p0 = (const unsigned short*)B + q * B_hstep + k * packn + r; + + for (int kk 
= 0; kk < max_kk; kk++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl4), vl4); + pp += 4; + p0 += packn; + } + } + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0, B_hstep * sizeof(unsigned short), vl4), vl4); + pp += 4; + p0++; + } + } + } +#endif // __riscv_vector + for (; jj + 1 < max_jj; jj += 2) + { +#if __riscv_vector + if (elempack == packn) + { + const int q = (j + jj) / packn * packn; + const int r = (j + jj) % packn; + const unsigned short* p0 = (const unsigned short*)B + q * B_hstep + k * packn + r; + + for (int kk = 0; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += packn; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + const unsigned short* p1 = (const unsigned short*)B + (j + jj + 1) * B_hstep + k; + + int kk = 0; +#if __riscv_vector + for (; kk + (packn - 1) < max_kk; kk += packn) + { + vuint16m1_t v0 = __riscv_vle16_v_u16m1(p0, vl); + vuint16m1_t v1 = __riscv_vle16_v_u16m1(p1, vl); + __riscv_vsseg2e16_v_u16m1x2(pp, __riscv_vcreate_v_u16m1x2(v0, v1), vl); + pp += packn * 2; + p0 += packn; + p1 += packn; + } +#endif // __riscv_vector + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp += 2; + p0++; + p1++; + } + } + } + for (; jj < max_jj; jj += 1) + { +#if __riscv_vector + if (elempack == packn) + { + const int q = (j + jj) / packn * packn; + const int r = (j + jj) % packn; + const unsigned short* p0 = (const unsigned short*)B + q * B_hstep + k * packn + r; + + for (int kk = 0; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += packn; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + + int kk = 0; +#if __riscv_vector + for (; kk + (packn - 1) < 
max_kk; kk += packn) + { + __riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl), vl); + pp += packn; + p0 += packn; + } +#endif // __riscv_vector + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0++; + } + } + } +} + +/* transpose_pack_B_tile_fp16sa: packs a tile of B into BT as 16-bit words. Columns j .. j + max_jj - 1 are taken in runs of 16, 8 and 4 (vector builds only), then 2, then 1; for every run the values of the run's columns are emitted row by row for the rows starting at row k. Sources with elempack == packn are gathered lane by lane from the packed layout, sources with elempack == 1 are copied row by row; any other elempack writes nothing. */ static void transpose_pack_B_tile_fp16sa(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ +#if __riscv_vector + /* packn is csrr_vlenb() / 2 - presumably the number of 16-bit lanes in one vector register (TODO confirm against csrr_vlenb) */ const int packn = csrr_vlenb() / 2; + const size_t vl = __riscv_vsetvl_e16m1(packn); + const size_t vl16 = __riscv_vsetvl_e16m2(16); + const size_t vl8 = __riscv_vsetvl_e16m1(8); + const size_t vl4 = __riscv_vsetvl_e16m1(4); +#endif + + const int elempack = B.elempack; + /* row stride used for all source addressing: cstep when B.dims == 3, otherwise w */ const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + /* pp is the write cursor into BT; it is only ever advanced, so the output order is exactly the loop order below */ unsigned short* pp = BT; + + int jj = 0; +#if __riscv_vector + /* runs of 16 columns */ for (; jj + 15 < max_jj; jj += 16) + { + if (elempack == packn) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * packn; + + /* at least 16 lanes per register: one load with a stride of packn elements picks lane l of 16 consecutive columns. The packed kk loops step by packn and have no scalar tail - NOTE(review): presumably max_kk is a multiple of packn on this path, confirm with the caller */ if (packn >= 16) + { + int kk = 0; + for (; kk + (packn - 1) < max_kk; kk += packn) + { + // transposeNx16 + for (int l = 0; l < packn; l++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0 + l, packn * sizeof(unsigned short), vl16), vl16); + pp += 16; + } + + p0 += B_hstep * packn; + } + } + if (packn == 8) + { + /* 8 lanes per register: columns 8..15 of the run start 8 columns * 8 lanes after p0, so each output row is written as two 8-element halves */ const unsigned short* p1 = p0 + 8 * 8; + + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + // transpose8x16 + for (int l = 0; l < 8; l++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0 + l, 8 * sizeof(unsigned short), vl8), vl8); + __riscv_vse16_v_u16m1(pp + 8, __riscv_vlse16_v_u16m1(p1 + l, 8 * sizeof(unsigned short), vl8), vl8); + pp += 16; + } + + p0 += B_hstep * 8; + p1 += B_hstep * 8; + } + } + } + /* unpacked source: copy 16 contiguous values from each of max_kk rows, stepping one row (B_hstep elements) at a time */ if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); + + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse16_v_u16m2(pp, __riscv_vle16_v_u16m2(p0, vl16), vl16); + pp += 16; + p0 += B_hstep; + } + } + } + /* runs of 8 columns, same scheme with vl8 */ for (; jj + 7 < max_jj; jj += 8) + { + if (elempack == packn) + { + const
unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * packn; + + int kk = 0; + for (; kk + (packn - 1) < max_kk; kk += packn) + { + // transposeNx8 + for (int l = 0; l < packn; l++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0 + l, packn * sizeof(unsigned short), vl8), vl8); + pp += 8; + } + + p0 += B_hstep * packn; + } + } + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl8), vl8); + pp += 8; + p0 += B_hstep; + } + } + } + /* runs of 4 columns, same scheme with vl4 */ for (; jj + 3 < max_jj; jj += 4) + { + if (elempack == packn) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * packn; + + int kk = 0; + for (; kk + (packn - 1) < max_kk; kk += packn) + { + // transposeNx4 + for (int l = 0; l < packn; l++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vlse16_v_u16m1(p0 + l, packn * sizeof(unsigned short), vl4), vl4); + pp += 4; + } + p0 += B_hstep * packn; + } + } + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + __riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl4), vl4); + pp += 4; + p0 += B_hstep; + } + } + } +#endif // __riscv_vector + /* pairs of columns; this loop and the single-column loop after it are also compiled when __riscv_vector is not set, where only their elempack == 1 branches exist */ for (; jj + 1 < max_jj; jj += 2) + { +#if __riscv_vector + if (elempack == packn) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * packn; + + int kk = 0; + for (; kk + (packn - 1) < max_kk; kk += packn) + { + /* load the packn lanes of two adjacent packed columns and store them with a 2-field segment store, i.e. interleaved as (column jj, column jj + 1) pairs */ vuint16m1_t v0 = __riscv_vle16_v_u16m1(p0, vl); + vuint16m1_t v1 = __riscv_vle16_v_u16m1(p0 + packn, vl); + __riscv_vsseg2e16_v_u16m1x2(pp, __riscv_vcreate_v_u16m1x2(v0, v1), vl); + pp += packn * 2; + p0 += B_hstep * packn; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0]
= p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += B_hstep; + } + } + } + /* remaining single column: a packed column is copied packn lanes at a time, an unpacked column one value per row */ for (; jj < max_jj; jj += 1) + { +#if __riscv_vector + if (elempack == packn) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * packn; + + int kk = 0; + for (; kk + (packn - 1) < max_kk; kk += packn) + { + __riscv_vse16_v_u16m1(pp, __riscv_vle16_v_u16m1(p0, vl), vl); + pp += packn; + p0 += B_hstep * packn; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += B_hstep; + } + } + } +} + +static void transpose_unpack_output_tile_fp16sa(const Mat& topT, Mat& top_blob, int i, int max_ii, int j, int max_jj) +{ +#if __riscv_vector + const int packn = csrr_vlenb() / 2; + const size_t vl = __riscv_vsetvl_e16m1(packn); + const size_t vl16 = __riscv_vsetvl_e16m2(16); + const size_t vl8 = __riscv_vsetvl_e16m1(8); + const size_t vl4 = __riscv_vsetvl_e16m1(4); +#endif + +#if __riscv_vector + const int out_elempack = top_blob.elempack; +#endif + const int out_hstep = top_blob.dims == 3 ?
(int)top_blob.cstep : top_blob.w; + + /* transpose_unpack_output_tile_fp16sa: pp reads the topT tile strictly sequentially and every value is written into top_blob with the tile indices swapped - the tile's j index selects the output row and the tile's i index the output column. Tile rows are consumed packn at a time (vector builds only), then in pairs, then singly. */ const __fp16* pp = topT; + + int ii = 0; +#if __riscv_vector + /* packn tile rows per pass: each packn-value vector read from pp belongs to one tile column jj and is spread over packn consecutive output columns starting at i + ii */ for (; ii + (packn - 1) < max_ii; ii += packn) + { + if (out_elempack == packn) + { + int jj = 0; + + /* j may start inside a packed group of output rows: fill the lanes r0 .. packn - 1 of that group first (or fewer if max_jj is smaller), one tile column per strided store */ const int r0 = j % packn; + if (r0 != 0) + { + const int nn = std::min(packn - r0, max_jj); + __fp16* p0 = (__fp16*)top_blob + (j / packn * packn) * out_hstep + r0 + (i + ii) * packn; + + for (; jj < nn; jj++) + { + __riscv_vsse16_v_f16m1(p0, packn * sizeof(__fp16), __riscv_vle16_v_f16m1(pp, vl), vl); + pp += packn; + p0++; + } + } + + /* from here on j + jj is a multiple of packn unless the tile columns are already exhausted: full packn x packn transposes first, then a lane-by-lane tail */ __fp16* p0 = (__fp16*)top_blob + (j + jj) * out_hstep + (i + ii) * packn; + + for (; jj + (packn - 1) < max_jj; jj += packn) + { + // transposeNxN + for (int l = 0; l < packn; l++) + { + __riscv_vsse16_v_f16m1(p0 + l, packn * sizeof(__fp16), __riscv_vle16_v_f16m1(pp, vl), vl); + pp += packn; + } + + p0 += out_hstep * packn; + } + + for (; jj < max_jj; jj++) + { + __riscv_vsse16_v_f16m1(p0, packn * sizeof(__fp16), __riscv_vle16_v_f16m1(pp, vl), vl); + pp += packn; + p0++; + } + } + /* unpacked output: each vector is stored contiguously in output row j + jj starting at column i + ii */ if (out_elempack == 1) + { + __fp16* p0 = (__fp16*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + __riscv_vse16_v_f16m1(p0, __riscv_vle16_v_f16m1(pp, vl), vl); + pp += packn; + p0 += out_hstep; + } + } + } +#endif // __riscv_vector + /* two tile rows per pass: for a run of N tile columns pp holds N values for tile row ii followed by N values for tile row ii + 1 */ for (; ii + 1 < max_ii; ii += 2) + { +#if __riscv_vector + if (out_elempack == packn) + { + int jj = 0; + + /* NOTE(review): the 16/8/4-wide stores below write consecutive lanes starting at out_j % packn, which looks like it assumes a run never crosses a packed row-group boundary - confirm against the caller's tiling */ for (; jj + 15 < max_jj; jj += 16) + { + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + /* 8-lane groups: the 16 output rows occupy two consecutive packed row groups, reached by advancing p0 by out_hstep * 8 */ if (packn == 8) + { + __riscv_vse16_v_f16m1(p0, __riscv_vle16_v_f16m1(pp, vl8), vl8); + __riscv_vse16_v_f16m1(p0 + packn, __riscv_vle16_v_f16m1(pp + 16, vl8), vl8); + p0 += out_hstep * 8; + __riscv_vse16_v_f16m1(p0, __riscv_vle16_v_f16m1(pp + 8, vl8), vl8); + __riscv_vse16_v_f16m1(p0 + packn, __riscv_vle16_v_f16m1(pp + 24, vl8), vl8); + } + if (packn >= 16) + { + __riscv_vse16_v_f16m2(p0,
__riscv_vle16_v_f16m2(pp, vl16), vl16); + __riscv_vse16_v_f16m2(p0 + packn, __riscv_vle16_v_f16m2(pp + 16, vl16), vl16); + } + pp += 16 * 2; + } + for (; jj + 7 < max_jj; jj += 8) + { + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + __riscv_vse16_v_f16m1(p0, __riscv_vle16_v_f16m1(pp, vl8), vl8); + __riscv_vse16_v_f16m1(p0 + packn, __riscv_vle16_v_f16m1(pp + 8, vl8), vl8); + pp += 8 * 2; + } + for (; jj + 3 < max_jj; jj += 4) + { + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + __riscv_vse16_v_f16m1(p0, __riscv_vle16_v_f16m1(pp, vl4), vl4); + __riscv_vse16_v_f16m1(p0 + packn, __riscv_vle16_v_f16m1(pp + 4, vl4), vl4); + pp += 4 * 2; + } + for (; jj + 1 < max_jj; jj += 2) + { + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + /* 2x2 scalar block: pp[0], pp[1] go to output column i + ii and pp[2], pp[3] to output column i + ii + 1 (one column is packn elements further in the packed layout) */ p0[0] = pp[0]; + p0[1] = pp[1]; + p0[packn] = pp[2]; + p0[packn + 1] = pp[3]; + pp += 2 * 2; + } + for (; jj < max_jj; jj += 1) + { + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + p0[0] = pp[0]; + p0[packn] = pp[1]; + pp += 2; + } + } + if (out_elempack == 1) +#endif // __riscv_vector + { + /* unpacked output (unconditional when __riscv_vector is not set): strided stores with a stride of out_hstep elements write one tile row down output column i + ii, the second tile row goes to column i + ii + 1 */ __fp16* p0 = (__fp16*)top_blob + j * out_hstep + (i + ii); + + int jj = 0; +#if __riscv_vector + for (; jj + 15 < max_jj; jj += 16) + { + __riscv_vsse16_v_f16m2(p0, out_hstep * sizeof(__fp16), __riscv_vle16_v_f16m2(pp, vl16), vl16); + __riscv_vsse16_v_f16m2(p0 + 1, out_hstep * sizeof(__fp16), __riscv_vle16_v_f16m2(pp + 16, vl16), vl16); + pp += 16 * 2; + p0 += out_hstep * 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + __riscv_vsse16_v_f16m1(p0, out_hstep * sizeof(__fp16), __riscv_vle16_v_f16m1(pp, vl8), vl8); + __riscv_vsse16_v_f16m1(p0 + 1, out_hstep * sizeof(__fp16), __riscv_vle16_v_f16m1(pp + 8, vl8), vl8); +
pp += 8 * 2; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + __riscv_vsse16_v_f16m1(p0, out_hstep * sizeof(__fp16), __riscv_vle16_v_f16m1(pp, vl4), vl4); + __riscv_vsse16_v_f16m1(p0 + 1, out_hstep * sizeof(__fp16), __riscv_vle16_v_f16m1(pp + 4, vl4), vl4); + pp += 4 * 2; + p0 += out_hstep * 4; + } +#endif // __riscv_vector + for (; jj + 1 < max_jj; jj += 2) + { + /* 2x2 scalar block: pp[0], pp[1] are output rows jj, jj + 1 of column i + ii; pp[2], pp[3] are the same rows of column i + ii + 1 */ p0[0] = pp[0]; + p0[out_hstep] = pp[1]; + p0[1] = pp[2]; + p0[out_hstep + 1] = pp[3]; + pp += 2 * 2; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + p0[0] = pp[0]; + p0[1] = pp[1]; + pp += 2; + p0 += out_hstep; + } + } + } + /* last single tile row: same column runs as above with one value per tile column; this group has no 2-column step, leftovers after the 4-wide run go one at a time */ for (; ii < max_ii; ii += 1) + { +#if __riscv_vector + if (out_elempack == packn) + { + int jj = 0; + + for (; jj + 15 < max_jj; jj += 16) + { + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + if (packn == 8) + { + __riscv_vse16_v_f16m1(p0, __riscv_vle16_v_f16m1(pp, vl8), vl8); + p0 += out_hstep * 8; + __riscv_vse16_v_f16m1(p0, __riscv_vle16_v_f16m1(pp + 8, vl8), vl8); + } + if (packn >= 16) + { + __riscv_vse16_v_f16m2(p0, __riscv_vle16_v_f16m2(pp, vl16), vl16); + } + pp += 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + __riscv_vse16_v_f16m1(p0, __riscv_vle16_v_f16m1(pp, vl8), vl8); + pp += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + __riscv_vse16_v_f16m1(p0, __riscv_vle16_v_f16m1(pp, vl4), vl4); + pp += 4; + } + for (; jj < max_jj; jj += 1) + { + const int out_j = j + jj; + __fp16* p0 = (__fp16*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + p0[0] = pp[0]; + pp += 1; + } + } + if (out_elempack == 1) +#endif // __riscv_vector + { + __fp16* p0 =
(__fp16*)top_blob + j * out_hstep + (i + ii); + + int jj = 0; +#if __riscv_vector + for (; jj + 15 < max_jj; jj += 16) + { + __riscv_vsse16_v_f16m2(p0, out_hstep * sizeof(__fp16), __riscv_vle16_v_f16m2(pp, vl16), vl16); + pp += 16; + p0 += out_hstep * 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + __riscv_vsse16_v_f16m1(p0, out_hstep * sizeof(__fp16), __riscv_vle16_v_f16m1(pp, vl8), vl8); + pp += 8; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + __riscv_vsse16_v_f16m1(p0, out_hstep * sizeof(__fp16), __riscv_vle16_v_f16m1(pp, vl4), vl4); + pp += 4; + p0 += out_hstep * 4; + } +#endif // __riscv_vector + /* scalar tail: one value per output row, stepping out_hstep elements down the column */ for (; jj < max_jj; jj += 1) + { + p0[0] = pp[0]; + pp += 1; + p0 += out_hstep; + } + } + } +} + +static void gemm_transB_packed_tile_fp16sa(const Mat& AT_tile, const Mat& BT_tile, const Mat& CT_tile, Mat& topT_tile, Mat& top_blob, int broadcast_type_C, float alpha, int i, int max_ii, int j, int max_jj, int k, int max_kk, bool k_end) +{ +#if __riscv_vector + const int packn = csrr_vlenb() / 2; + const size_t vl = __riscv_vsetvl_e16m1(packn); + const size_t vl16 = __riscv_vsetvl_e16m2(16); + const size_t vl8 = __riscv_vsetvl_e16m1(8); + const size_t vl4 = __riscv_vsetvl_e16m1(4); +#endif + +#if __riscv_vector + const int out_elempack = top_blob.elempack; +#endif + const int out_hstep = top_blob.dims == 3 ?
(int)top_blob.cstep : top_blob.w; + const __fp16 _alpha = (__fp16)alpha; + + const __fp16* pAT = AT_tile; + const __fp16* pBT = BT_tile; + const __fp16* pC = CT_tile; + + __fp16* outptr = topT_tile; + + int ii = 0; +#if __riscv_vector + for (; ii + (packn - 1) < max_ii; ii += packn) + { + __fp16* outptr0 = (__fp16*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + const __fp16* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const __fp16*)CT_tile + i + ii; + } + if (broadcast_type_C == 4) + { + pC = (const __fp16*)CT_tile + j; + } + } + + int jj = 0; + for (; jj + 15 < max_jj; jj += 16) + { + vfloat16m1_t _sum0; + vfloat16m1_t _sum1; + vfloat16m1_t _sum2; + vfloat16m1_t _sum3; + vfloat16m1_t _sum4; + vfloat16m1_t _sum5; + vfloat16m1_t _sum6; + vfloat16m1_t _sum7; + vfloat16m1_t _sum8; + vfloat16m1_t _sum9; + vfloat16m1_t _suma; + vfloat16m1_t _sumb; + vfloat16m1_t _sumc; + vfloat16m1_t _sumd; + vfloat16m1_t _sume; + vfloat16m1_t _sumf; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = __riscv_vfmv_v_f_f16m1(pC[0], vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + _suma = _sum0; + _sumb = _sum0; + _sumc = _sum0; + _sumd = _sum0; + _sume = _sum0; + _sumf = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = __riscv_vle16_v_f16m1(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + _suma = _sum0; + _sumb = _sum0; + _sumc = _sum0; + _sumd = _sum0; + _sume = _sum0; + _sumf = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = __riscv_vle16_v_f16m1(pC, vl); + _sum1 = __riscv_vle16_v_f16m1(pC + packn, vl); + _sum2 = __riscv_vle16_v_f16m1(pC + packn * 2, vl); + _sum3 = __riscv_vle16_v_f16m1(pC + packn * 3, vl); + _sum4 = __riscv_vle16_v_f16m1(pC + 
packn * 4, vl); + _sum5 = __riscv_vle16_v_f16m1(pC + packn * 5, vl); + _sum6 = __riscv_vle16_v_f16m1(pC + packn * 6, vl); + _sum7 = __riscv_vle16_v_f16m1(pC + packn * 7, vl); + _sum8 = __riscv_vle16_v_f16m1(pC + packn * 8, vl); + _sum9 = __riscv_vle16_v_f16m1(pC + packn * 9, vl); + _suma = __riscv_vle16_v_f16m1(pC + packn * 10, vl); + _sumb = __riscv_vle16_v_f16m1(pC + packn * 11, vl); + _sumc = __riscv_vle16_v_f16m1(pC + packn * 12, vl); + _sumd = __riscv_vle16_v_f16m1(pC + packn * 13, vl); + _sume = __riscv_vle16_v_f16m1(pC + packn * 14, vl); + _sumf = __riscv_vle16_v_f16m1(pC + packn * 15, vl); + pC += packn * 16; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vfmv_v_f_f16m1(pC[0], vl); + _sum1 = __riscv_vfmv_v_f_f16m1(pC[1], vl); + _sum2 = __riscv_vfmv_v_f_f16m1(pC[2], vl); + _sum3 = __riscv_vfmv_v_f_f16m1(pC[3], vl); + _sum4 = __riscv_vfmv_v_f_f16m1(pC[4], vl); + _sum5 = __riscv_vfmv_v_f_f16m1(pC[5], vl); + _sum6 = __riscv_vfmv_v_f_f16m1(pC[6], vl); + _sum7 = __riscv_vfmv_v_f_f16m1(pC[7], vl); + _sum8 = __riscv_vfmv_v_f_f16m1(pC[8], vl); + _sum9 = __riscv_vfmv_v_f_f16m1(pC[9], vl); + _suma = __riscv_vfmv_v_f_f16m1(pC[10], vl); + _sumb = __riscv_vfmv_v_f_f16m1(pC[11], vl); + _sumc = __riscv_vfmv_v_f_f16m1(pC[12], vl); + _sumd = __riscv_vfmv_v_f_f16m1(pC[13], vl); + _sume = __riscv_vfmv_v_f_f16m1(pC[14], vl); + _sumf = __riscv_vfmv_v_f_f16m1(pC[15], vl); + pC += 16; + } + } + else + { + _sum0 = __riscv_vfmv_v_f_f16m1((__fp16)0.f, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + _suma = _sum0; + _sumb = _sum0; + _sumc = _sum0; + _sumd = _sum0; + _sume = _sum0; + _sumf = _sum0; + } + } + else + { + _sum0 = __riscv_vle16_v_f16m1(outptr, vl); + _sum1 = __riscv_vle16_v_f16m1(outptr + packn, vl); + _sum2 = __riscv_vle16_v_f16m1(outptr + packn * 2, vl); + _sum3 = __riscv_vle16_v_f16m1(outptr + packn * 3, vl); + _sum4 = __riscv_vle16_v_f16m1(outptr + 
packn * 4, vl); + _sum5 = __riscv_vle16_v_f16m1(outptr + packn * 5, vl); + _sum6 = __riscv_vle16_v_f16m1(outptr + packn * 6, vl); + _sum7 = __riscv_vle16_v_f16m1(outptr + packn * 7, vl); + _sum8 = __riscv_vle16_v_f16m1(outptr + packn * 8, vl); + _sum9 = __riscv_vle16_v_f16m1(outptr + packn * 9, vl); + _suma = __riscv_vle16_v_f16m1(outptr + packn * 10, vl); + _sumb = __riscv_vle16_v_f16m1(outptr + packn * 11, vl); + _sumc = __riscv_vle16_v_f16m1(outptr + packn * 12, vl); + _sumd = __riscv_vle16_v_f16m1(outptr + packn * 13, vl); + _sume = __riscv_vle16_v_f16m1(outptr + packn * 14, vl); + _sumf = __riscv_vle16_v_f16m1(outptr + packn * 15, vl); + } + + const __fp16* pA = pAT; + for (int kk = 0; kk < max_kk; kk++) + { + vfloat16m1_t _pA = __riscv_vle16_v_f16m1(pA, vl); + _sum0 = __riscv_vfmacc_vf_f16m1(_sum0, pB[0], _pA, vl); + _sum1 = __riscv_vfmacc_vf_f16m1(_sum1, pB[1], _pA, vl); + _sum2 = __riscv_vfmacc_vf_f16m1(_sum2, pB[2], _pA, vl); + _sum3 = __riscv_vfmacc_vf_f16m1(_sum3, pB[3], _pA, vl); + _sum4 = __riscv_vfmacc_vf_f16m1(_sum4, pB[4], _pA, vl); + _sum5 = __riscv_vfmacc_vf_f16m1(_sum5, pB[5], _pA, vl); + _sum6 = __riscv_vfmacc_vf_f16m1(_sum6, pB[6], _pA, vl); + _sum7 = __riscv_vfmacc_vf_f16m1(_sum7, pB[7], _pA, vl); + _sum8 = __riscv_vfmacc_vf_f16m1(_sum8, pB[8], _pA, vl); + _sum9 = __riscv_vfmacc_vf_f16m1(_sum9, pB[9], _pA, vl); + _suma = __riscv_vfmacc_vf_f16m1(_suma, pB[10], _pA, vl); + _sumb = __riscv_vfmacc_vf_f16m1(_sumb, pB[11], _pA, vl); + _sumc = __riscv_vfmacc_vf_f16m1(_sumc, pB[12], _pA, vl); + _sumd = __riscv_vfmacc_vf_f16m1(_sumd, pB[13], _pA, vl); + _sume = __riscv_vfmacc_vf_f16m1(_sume, pB[14], _pA, vl); + _sumf = __riscv_vfmacc_vf_f16m1(_sumf, pB[15], _pA, vl); + pA += packn; + pB += 16; + } + + if (alpha != 1.f) + { + _sum0 = __riscv_vfmul_vf_f16m1(_sum0, _alpha, vl); + _sum1 = __riscv_vfmul_vf_f16m1(_sum1, _alpha, vl); + _sum2 = __riscv_vfmul_vf_f16m1(_sum2, _alpha, vl); + _sum3 = __riscv_vfmul_vf_f16m1(_sum3, _alpha, vl); + _sum4 = 
__riscv_vfmul_vf_f16m1(_sum4, _alpha, vl); + _sum5 = __riscv_vfmul_vf_f16m1(_sum5, _alpha, vl); + _sum6 = __riscv_vfmul_vf_f16m1(_sum6, _alpha, vl); + _sum7 = __riscv_vfmul_vf_f16m1(_sum7, _alpha, vl); + _sum8 = __riscv_vfmul_vf_f16m1(_sum8, _alpha, vl); + _sum9 = __riscv_vfmul_vf_f16m1(_sum9, _alpha, vl); + _suma = __riscv_vfmul_vf_f16m1(_suma, _alpha, vl); + _sumb = __riscv_vfmul_vf_f16m1(_sumb, _alpha, vl); + _sumc = __riscv_vfmul_vf_f16m1(_sumc, _alpha, vl); + _sumd = __riscv_vfmul_vf_f16m1(_sumd, _alpha, vl); + _sume = __riscv_vfmul_vf_f16m1(_sume, _alpha, vl); + _sumf = __riscv_vfmul_vf_f16m1(_sumf, _alpha, vl); + } + + if (k_end) + { + if (out_elempack == packn) + { + __riscv_vse16_v_f16m1(outptr0, _sum0, vl); + __riscv_vse16_v_f16m1(outptr0 + packn, _sum1, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 2, _sum2, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 3, _sum3, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 4, _sum4, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 5, _sum5, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 6, _sum6, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 7, _sum7, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 8, _sum8, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 9, _sum9, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 10, _suma, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 11, _sumb, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 12, _sumc, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 13, _sumd, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 14, _sume, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 15, _sumf, vl); + outptr0 += packn * 16; + } + if (out_elempack == 1) + { + vfloat16m1x8_t _sum01 = __riscv_vcreate_v_f16m1x8(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7); + vfloat16m1x8_t _sum23 = __riscv_vcreate_v_f16m1x8(_sum8, _sum9, _suma, _sumb, _sumc, _sumd, _sume, _sumf); + __riscv_vssseg8e16_v_f16m1x8(outptr0, out_hstep * sizeof(__fp16), _sum01, vl); + __riscv_vssseg8e16_v_f16m1x8(outptr0 + 8, out_hstep * 
sizeof(__fp16), _sum23, vl); + outptr0 += 16; + } + } + else + { + __riscv_vse16_v_f16m1(outptr, _sum0, vl); + __riscv_vse16_v_f16m1(outptr + packn, _sum1, vl); + __riscv_vse16_v_f16m1(outptr + packn * 2, _sum2, vl); + __riscv_vse16_v_f16m1(outptr + packn * 3, _sum3, vl); + __riscv_vse16_v_f16m1(outptr + packn * 4, _sum4, vl); + __riscv_vse16_v_f16m1(outptr + packn * 5, _sum5, vl); + __riscv_vse16_v_f16m1(outptr + packn * 6, _sum6, vl); + __riscv_vse16_v_f16m1(outptr + packn * 7, _sum7, vl); + __riscv_vse16_v_f16m1(outptr + packn * 8, _sum8, vl); + __riscv_vse16_v_f16m1(outptr + packn * 9, _sum9, vl); + __riscv_vse16_v_f16m1(outptr + packn * 10, _suma, vl); + __riscv_vse16_v_f16m1(outptr + packn * 11, _sumb, vl); + __riscv_vse16_v_f16m1(outptr + packn * 12, _sumc, vl); + __riscv_vse16_v_f16m1(outptr + packn * 13, _sumd, vl); + __riscv_vse16_v_f16m1(outptr + packn * 14, _sume, vl); + __riscv_vse16_v_f16m1(outptr + packn * 15, _sumf, vl); + } + + outptr += packn * 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + vfloat16m1_t _sum0; + vfloat16m1_t _sum1; + vfloat16m1_t _sum2; + vfloat16m1_t _sum3; + vfloat16m1_t _sum4; + vfloat16m1_t _sum5; + vfloat16m1_t _sum6; + vfloat16m1_t _sum7; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = __riscv_vfmv_v_f_f16m1(pC[0], vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = __riscv_vle16_v_f16m1(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = __riscv_vle16_v_f16m1(pC, vl); + _sum1 = __riscv_vle16_v_f16m1(pC + packn, vl); + _sum2 = __riscv_vle16_v_f16m1(pC + packn * 2, vl); + _sum3 = __riscv_vle16_v_f16m1(pC + packn * 3, vl); + _sum4 = __riscv_vle16_v_f16m1(pC + packn * 4, vl); + _sum5 = __riscv_vle16_v_f16m1(pC + packn * 5, vl); + _sum6 
= __riscv_vle16_v_f16m1(pC + packn * 6, vl); + _sum7 = __riscv_vle16_v_f16m1(pC + packn * 7, vl); + pC += packn * 8; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vfmv_v_f_f16m1(pC[0], vl); + _sum1 = __riscv_vfmv_v_f_f16m1(pC[1], vl); + _sum2 = __riscv_vfmv_v_f_f16m1(pC[2], vl); + _sum3 = __riscv_vfmv_v_f_f16m1(pC[3], vl); + _sum4 = __riscv_vfmv_v_f_f16m1(pC[4], vl); + _sum5 = __riscv_vfmv_v_f_f16m1(pC[5], vl); + _sum6 = __riscv_vfmv_v_f_f16m1(pC[6], vl); + _sum7 = __riscv_vfmv_v_f_f16m1(pC[7], vl); + pC += 8; + } + } + else + { + _sum0 = __riscv_vfmv_v_f_f16m1((__fp16)0.f, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + } + else + { + _sum0 = __riscv_vle16_v_f16m1(outptr, vl); + _sum1 = __riscv_vle16_v_f16m1(outptr + packn, vl); + _sum2 = __riscv_vle16_v_f16m1(outptr + packn * 2, vl); + _sum3 = __riscv_vle16_v_f16m1(outptr + packn * 3, vl); + _sum4 = __riscv_vle16_v_f16m1(outptr + packn * 4, vl); + _sum5 = __riscv_vle16_v_f16m1(outptr + packn * 5, vl); + _sum6 = __riscv_vle16_v_f16m1(outptr + packn * 6, vl); + _sum7 = __riscv_vle16_v_f16m1(outptr + packn * 7, vl); + } + + const __fp16* pA = pAT; + for (int kk = 0; kk < max_kk; kk++) + { + vfloat16m1_t _pA = __riscv_vle16_v_f16m1(pA, vl); + _sum0 = __riscv_vfmacc_vf_f16m1(_sum0, pB[0], _pA, vl); + _sum1 = __riscv_vfmacc_vf_f16m1(_sum1, pB[1], _pA, vl); + _sum2 = __riscv_vfmacc_vf_f16m1(_sum2, pB[2], _pA, vl); + _sum3 = __riscv_vfmacc_vf_f16m1(_sum3, pB[3], _pA, vl); + _sum4 = __riscv_vfmacc_vf_f16m1(_sum4, pB[4], _pA, vl); + _sum5 = __riscv_vfmacc_vf_f16m1(_sum5, pB[5], _pA, vl); + _sum6 = __riscv_vfmacc_vf_f16m1(_sum6, pB[6], _pA, vl); + _sum7 = __riscv_vfmacc_vf_f16m1(_sum7, pB[7], _pA, vl); + pA += packn; + pB += 8; + } + + if (alpha != 1.f) + { + _sum0 = __riscv_vfmul_vf_f16m1(_sum0, _alpha, vl); + _sum1 = __riscv_vfmul_vf_f16m1(_sum1, _alpha, vl); + _sum2 = __riscv_vfmul_vf_f16m1(_sum2, _alpha, vl); + _sum3 = 
__riscv_vfmul_vf_f16m1(_sum3, _alpha, vl); + _sum4 = __riscv_vfmul_vf_f16m1(_sum4, _alpha, vl); + _sum5 = __riscv_vfmul_vf_f16m1(_sum5, _alpha, vl); + _sum6 = __riscv_vfmul_vf_f16m1(_sum6, _alpha, vl); + _sum7 = __riscv_vfmul_vf_f16m1(_sum7, _alpha, vl); + } + + if (k_end) + { + if (out_elempack == packn) + { + __riscv_vse16_v_f16m1(outptr0, _sum0, vl); + __riscv_vse16_v_f16m1(outptr0 + packn, _sum1, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 2, _sum2, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 3, _sum3, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 4, _sum4, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 5, _sum5, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 6, _sum6, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 7, _sum7, vl); + outptr0 += packn * 8; + } + if (out_elempack == 1) + { + vfloat16m1x8_t _sum = __riscv_vcreate_v_f16m1x8(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7); + __riscv_vssseg8e16_v_f16m1x8(outptr0, out_hstep * sizeof(__fp16), _sum, vl); + outptr0 += 8; + } + } + else + { + __riscv_vse16_v_f16m1(outptr, _sum0, vl); + __riscv_vse16_v_f16m1(outptr + packn, _sum1, vl); + __riscv_vse16_v_f16m1(outptr + packn * 2, _sum2, vl); + __riscv_vse16_v_f16m1(outptr + packn * 3, _sum3, vl); + __riscv_vse16_v_f16m1(outptr + packn * 4, _sum4, vl); + __riscv_vse16_v_f16m1(outptr + packn * 5, _sum5, vl); + __riscv_vse16_v_f16m1(outptr + packn * 6, _sum6, vl); + __riscv_vse16_v_f16m1(outptr + packn * 7, _sum7, vl); + } + + outptr += packn * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat16m1_t _sum0; + vfloat16m1_t _sum1; + vfloat16m1_t _sum2; + vfloat16m1_t _sum3; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = __riscv_vfmv_v_f_f16m1(pC[0], vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = __riscv_vle16_v_f16m1(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = 
__riscv_vle16_v_f16m1(pC, vl); + _sum1 = __riscv_vle16_v_f16m1(pC + packn, vl); + _sum2 = __riscv_vle16_v_f16m1(pC + packn * 2, vl); + _sum3 = __riscv_vle16_v_f16m1(pC + packn * 3, vl); + pC += packn * 4; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vfmv_v_f_f16m1(pC[0], vl); + _sum1 = __riscv_vfmv_v_f_f16m1(pC[1], vl); + _sum2 = __riscv_vfmv_v_f_f16m1(pC[2], vl); + _sum3 = __riscv_vfmv_v_f_f16m1(pC[3], vl); + pC += 4; + } + } + else + { + _sum0 = __riscv_vfmv_v_f_f16m1((__fp16)0.f, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + } + else + { + _sum0 = __riscv_vle16_v_f16m1(outptr, vl); + _sum1 = __riscv_vle16_v_f16m1(outptr + packn, vl); + _sum2 = __riscv_vle16_v_f16m1(outptr + packn * 2, vl); + _sum3 = __riscv_vle16_v_f16m1(outptr + packn * 3, vl); + } + + const __fp16* pA = pAT; + for (int kk = 0; kk < max_kk; kk++) + { + vfloat16m1_t _pA = __riscv_vle16_v_f16m1(pA, vl); + _sum0 = __riscv_vfmacc_vf_f16m1(_sum0, pB[0], _pA, vl); + _sum1 = __riscv_vfmacc_vf_f16m1(_sum1, pB[1], _pA, vl); + _sum2 = __riscv_vfmacc_vf_f16m1(_sum2, pB[2], _pA, vl); + _sum3 = __riscv_vfmacc_vf_f16m1(_sum3, pB[3], _pA, vl); + pA += packn; + pB += 4; + } + + if (alpha != 1.f) + { + _sum0 = __riscv_vfmul_vf_f16m1(_sum0, _alpha, vl); + _sum1 = __riscv_vfmul_vf_f16m1(_sum1, _alpha, vl); + _sum2 = __riscv_vfmul_vf_f16m1(_sum2, _alpha, vl); + _sum3 = __riscv_vfmul_vf_f16m1(_sum3, _alpha, vl); + } + + if (k_end) + { + if (out_elempack == packn) + { + __riscv_vse16_v_f16m1(outptr0, _sum0, vl); + __riscv_vse16_v_f16m1(outptr0 + packn, _sum1, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 2, _sum2, vl); + __riscv_vse16_v_f16m1(outptr0 + packn * 3, _sum3, vl); + outptr0 += packn * 4; + } + if (out_elempack == 1) + { + vfloat16m1x4_t _sum = __riscv_vcreate_v_f16m1x4(_sum0, _sum1, _sum2, _sum3); + __riscv_vssseg4e16_v_f16m1x4(outptr0, out_hstep * sizeof(__fp16), _sum, vl); + outptr0 += 4; + } + } + else + { + __riscv_vse16_v_f16m1(outptr, _sum0, vl); + 
__riscv_vse16_v_f16m1(outptr + packn, _sum1, vl); + __riscv_vse16_v_f16m1(outptr + packn * 2, _sum2, vl); + __riscv_vse16_v_f16m1(outptr + packn * 3, _sum3, vl); + } + + outptr += packn * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + vfloat16m1_t _sum0; + vfloat16m1_t _sum1; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = __riscv_vfmv_v_f_f16m1(pC[0], vl); + _sum1 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = __riscv_vle16_v_f16m1(pC, vl); + _sum1 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = __riscv_vle16_v_f16m1(pC, vl); + _sum1 = __riscv_vle16_v_f16m1(pC + packn, vl); + pC += packn * 2; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vfmv_v_f_f16m1(pC[0], vl); + _sum1 = __riscv_vfmv_v_f_f16m1(pC[1], vl); + pC += 2; + } + } + else + { + _sum0 = __riscv_vfmv_v_f_f16m1((__fp16)0.f, vl); + _sum1 = _sum0; + } + } + else + { + _sum0 = __riscv_vle16_v_f16m1(outptr, vl); + _sum1 = __riscv_vle16_v_f16m1(outptr + packn, vl); + } + + const __fp16* pA = pAT; + for (int kk = 0; kk < max_kk; kk++) + { + vfloat16m1_t _pA = __riscv_vle16_v_f16m1(pA, vl); + _sum0 = __riscv_vfmacc_vf_f16m1(_sum0, pB[0], _pA, vl); + _sum1 = __riscv_vfmacc_vf_f16m1(_sum1, pB[1], _pA, vl); + pA += packn; + pB += 2; + } + + if (alpha != 1.f) + { + _sum0 = __riscv_vfmul_vf_f16m1(_sum0, _alpha, vl); + _sum1 = __riscv_vfmul_vf_f16m1(_sum1, _alpha, vl); + } + + if (k_end) + { + if (out_elempack == packn) + { + __riscv_vse16_v_f16m1(outptr0, _sum0, vl); + __riscv_vse16_v_f16m1(outptr0 + packn, _sum1, vl); + outptr0 += packn * 2; + } + if (out_elempack == 1) + { + vfloat16m1x2_t _sum = __riscv_vcreate_v_f16m1x2(_sum0, _sum1); + __riscv_vssseg2e16_v_f16m1x2(outptr0, out_hstep * sizeof(__fp16), _sum, vl); + outptr0 += 2; + } + } + else + { + __riscv_vse16_v_f16m1(outptr, _sum0, vl); + __riscv_vse16_v_f16m1(outptr + packn, _sum1, vl); + } + + outptr += packn * 2; + } + for (; jj < max_jj; jj += 1) + { + vfloat16m1_t _sum0; + + 
if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0) + _sum0 = __riscv_vfmv_v_f_f16m1(pC[0], vl); + if (broadcast_type_C == 1 || broadcast_type_C == 2) + _sum0 = __riscv_vle16_v_f16m1(pC, vl); + if (broadcast_type_C == 3) + { + _sum0 = __riscv_vle16_v_f16m1(pC, vl); + pC += packn; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vfmv_v_f_f16m1(pC[0], vl); + pC += 1; + } + } + else + { + _sum0 = __riscv_vfmv_v_f_f16m1((__fp16)0.f, vl); + } + } + else + { + _sum0 = __riscv_vle16_v_f16m1(outptr, vl); + } + + const __fp16* pA = pAT; + for (int kk = 0; kk < max_kk; kk++) + { + vfloat16m1_t _pA = __riscv_vle16_v_f16m1(pA, vl); + _sum0 = __riscv_vfmacc_vf_f16m1(_sum0, pB[0], _pA, vl); + pA += packn; + pB += 1; + } + + if (alpha != 1.f) + { + _sum0 = __riscv_vfmul_vf_f16m1(_sum0, _alpha, vl); + } + + if (k_end) + { + if (out_elempack == packn) + { + __riscv_vse16_v_f16m1(outptr0, _sum0, vl); + outptr0 += packn; + } + if (out_elempack == 1) + { + __riscv_vsse16_v_f16m1(outptr0, out_hstep * sizeof(__fp16), _sum0, vl); + outptr0++; + } + } + else + { + __riscv_vse16_v_f16m1(outptr, _sum0, vl); + } + + outptr += packn; + } + + pAT += max_kk * packn; + } +#endif // __riscv_vector + for (; ii + 1 < max_ii; ii += 2) + { + __fp16* outptr0 = (__fp16*)top_blob + (i + ii) * out_hstep + j; + + const __fp16* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + pC = (const __fp16*)CT_tile + i + ii; + if (broadcast_type_C == 4) + pC = (const __fp16*)CT_tile + j; + } + + int jj = 0; +#if __riscv_vector + for (; jj + 15 < max_jj; jj += 16) + { + vfloat16m2_t _sum0; + vfloat16m2_t _sum1; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = __riscv_vfmv_v_f_f16m2(pC[0], vl16); + _sum1 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = __riscv_vfmv_v_f_f16m2(pC[0], vl16); + _sum1 = __riscv_vfmv_v_f_f16m2(pC[1], vl16); + } + if (broadcast_type_C == 3) + { + vfloat16m2x2_t _s0 = 
__riscv_vlseg2e16_v_f16m2x2(pC, vl16); + _sum0 = __riscv_vget_v_f16m2x2_f16m2(_s0, 0); + _sum1 = __riscv_vget_v_f16m2x2_f16m2(_s0, 1); + pC += 16 * 2; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vle16_v_f16m2(pC, vl16); + _sum1 = _sum0; + pC += 16; + } + } + else + { + _sum0 = __riscv_vfmv_v_f_f16m2((__fp16)0.f, vl16); + _sum1 = _sum0; + } + } + else + { + _sum0 = __riscv_vle16_v_f16m2(outptr, vl16); + _sum1 = __riscv_vle16_v_f16m2(outptr + 16, vl16); + } + + const __fp16* pA = pAT; + for (int kk = 0; kk < max_kk; kk++) + { + vfloat16m2_t _pB = __riscv_vle16_v_f16m2(pB, vl16); + _sum0 = __riscv_vfmacc_vf_f16m2(_sum0, pA[0], _pB, vl16); + _sum1 = __riscv_vfmacc_vf_f16m2(_sum1, pA[1], _pB, vl16); + pA += 2; + pB += 16; + } + + if (alpha != 1.f) + { + _sum0 = __riscv_vfmul_vf_f16m2(_sum0, _alpha, vl16); + _sum1 = __riscv_vfmul_vf_f16m2(_sum1, _alpha, vl16); + } + + if (k_end) + { + __riscv_vse16_v_f16m2(outptr0, _sum0, vl16); + __riscv_vse16_v_f16m2(outptr0 + out_hstep, _sum1, vl16); + outptr0 += 16; + } + else + { + __riscv_vse16_v_f16m2(outptr, _sum0, vl16); + __riscv_vse16_v_f16m2(outptr + 16, _sum1, vl16); + } + + outptr += 16 * 2; + } + for (; jj + 7 < max_jj; jj += 8) + { + vfloat16m1_t _sum0; + vfloat16m1_t _sum1; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = __riscv_vfmv_v_f_f16m1(pC[0], vl); + _sum1 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = __riscv_vfmv_v_f_f16m1(pC[0], vl); + _sum1 = __riscv_vfmv_v_f_f16m1(pC[1], vl); + } + if (broadcast_type_C == 3) + { + vfloat16m1x2_t _s0 = __riscv_vlseg2e16_v_f16m1x2(pC, vl8); + _sum0 = __riscv_vget_v_f16m1x2_f16m1(_s0, 0); + _sum1 = __riscv_vget_v_f16m1x2_f16m1(_s0, 1); + pC += 8 * 2; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vle16_v_f16m1(pC, vl8); + _sum1 = _sum0; + pC += 8; + } + } + else + { + _sum0 = __riscv_vfmv_v_f_f16m1((__fp16)0.f, vl); + _sum1 = _sum0; + } + } + else + { + _sum0 = __riscv_vle16_v_f16m1(outptr, vl8); + 
_sum1 = __riscv_vle16_v_f16m1(outptr + 8, vl8); + } + + const __fp16* pA = pAT; + for (int kk = 0; kk < max_kk; kk++) + { + vfloat16m1_t _pB = __riscv_vle16_v_f16m1(pB, vl8); + _sum0 = __riscv_vfmacc_vf_f16m1(_sum0, pA[0], _pB, vl8); + _sum1 = __riscv_vfmacc_vf_f16m1(_sum1, pA[1], _pB, vl8); + pA += 2; + pB += 8; + } + + if (alpha != 1.f) + { + _sum0 = __riscv_vfmul_vf_f16m1(_sum0, _alpha, vl8); + _sum1 = __riscv_vfmul_vf_f16m1(_sum1, _alpha, vl8); + } + + if (k_end) + { + __riscv_vse16_v_f16m1(outptr0, _sum0, vl8); + __riscv_vse16_v_f16m1(outptr0 + out_hstep, _sum1, vl8); + outptr0 += 8; + } + else + { + __riscv_vse16_v_f16m1(outptr, _sum0, vl8); + __riscv_vse16_v_f16m1(outptr + 8, _sum1, vl8); + } + + outptr += 8 * 2; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat16m1_t _sum0; + vfloat16m1_t _sum1; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = __riscv_vfmv_v_f_f16m1(pC[0], vl4); + _sum1 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = __riscv_vfmv_v_f_f16m1(pC[0], vl4); + _sum1 = __riscv_vfmv_v_f_f16m1(pC[1], vl4); + } + if (broadcast_type_C == 3) + { + vfloat16m1x2_t _s0 = __riscv_vlseg2e16_v_f16m1x2(pC, vl4); + _sum0 = __riscv_vget_v_f16m1x2_f16m1(_s0, 0); + _sum1 = __riscv_vget_v_f16m1x2_f16m1(_s0, 1); + pC += 4 * 2; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vle16_v_f16m1(pC, vl4); + _sum1 = _sum0; + pC += 4; + } + } + else + { + _sum0 = __riscv_vfmv_v_f_f16m1((__fp16)0.f, vl4); + _sum1 = _sum0; + } + } + else + { + _sum0 = __riscv_vle16_v_f16m1(outptr, vl4); + _sum1 = __riscv_vle16_v_f16m1(outptr + 4, vl4); + } + + const __fp16* pA = pAT; + for (int kk = 0; kk < max_kk; kk++) + { + vfloat16m1_t _pB = __riscv_vle16_v_f16m1(pB, vl4); + _sum0 = __riscv_vfmacc_vf_f16m1(_sum0, pA[0], _pB, vl4); + _sum1 = __riscv_vfmacc_vf_f16m1(_sum1, pA[1], _pB, vl4); + pA += 2; + pB += 4; + } + + if (alpha != 1.f) + { + _sum0 = __riscv_vfmul_vf_f16m1(_sum0, _alpha, vl4); + _sum1 = 
__riscv_vfmul_vf_f16m1(_sum1, _alpha, vl4); + } + + if (k_end) + { + __riscv_vse16_v_f16m1(outptr0, _sum0, vl4); + __riscv_vse16_v_f16m1(outptr0 + out_hstep, _sum1, vl4); + outptr0 += 4; + } + else + { + __riscv_vse16_v_f16m1(outptr, _sum0, vl4); + __riscv_vse16_v_f16m1(outptr + 4, _sum1, vl4); + } + + outptr += 4 * 2; + } +#endif // __riscv_vector + for (; jj + 1 < max_jj; jj += 2) + { + __fp16 sum00; + __fp16 sum01; + __fp16 sum10; + __fp16 sum11; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0) + { + sum00 = pC[0]; + sum01 = pC[0]; + sum10 = pC[0]; + sum11 = pC[0]; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + sum00 = pC[0]; + sum01 = pC[1]; + sum10 = pC[0]; + sum11 = pC[1]; + } + if (broadcast_type_C == 3) + { + sum00 = pC[0]; + sum01 = pC[1]; + sum10 = pC[2]; + sum11 = pC[3]; + pC += 4; + } + if (broadcast_type_C == 4) + { + sum00 = pC[0]; + sum01 = pC[0]; + sum10 = pC[1]; + sum11 = pC[1]; + pC += 2; + } + } + else + { + sum00 = (__fp16)0.f; + sum01 = (__fp16)0.f; + sum10 = (__fp16)0.f; + sum11 = (__fp16)0.f; + } + } + else + { + sum00 = outptr[0]; + sum10 = outptr[1]; + sum01 = outptr[2]; + sum11 = outptr[3]; + } + + const __fp16* pA = pAT; + for (int kk = 0; kk < max_kk; kk++) + { + sum00 = (__fp16)(sum00 + pA[0] * pB[0]); + sum01 = (__fp16)(sum01 + pA[1] * pB[0]); + sum10 = (__fp16)(sum10 + pA[0] * pB[1]); + sum11 = (__fp16)(sum11 + pA[1] * pB[1]); + pA += 2; + pB += 2; + } + + if (alpha != 1.f) + { + sum00 = (__fp16)(sum00 * _alpha); + sum01 = (__fp16)(sum01 * _alpha); + sum10 = (__fp16)(sum10 * _alpha); + sum11 = (__fp16)(sum11 * _alpha); + } + + if (k_end) + { + outptr0[0] = sum00; + outptr0[1] = sum10; + outptr0[out_hstep] = sum01; + outptr0[out_hstep + 1] = sum11; + outptr0 += 2; + } + else + { + outptr[0] = sum00; + outptr[1] = sum10; + outptr[2] = sum01; + outptr[3] = sum11; + } + + outptr += 4; + } + for (; jj < max_jj; jj += 1) + { + __fp16 sum0; + __fp16 sum1; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C 
== 0) + { + sum0 = pC[0]; + sum1 = pC[0]; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + sum0 = pC[0]; + sum1 = pC[1]; + } + if (broadcast_type_C == 3) + { + sum0 = pC[0]; + sum1 = pC[1]; + pC += 2; + } + if (broadcast_type_C == 4) + { + sum0 = pC[0]; + sum1 = pC[0]; + pC += 1; + } + } + else + { + sum0 = (__fp16)0.f; + sum1 = (__fp16)0.f; + } + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const __fp16* pA = pAT; + for (int kk = 0; kk < max_kk; kk++) + { + sum0 = (__fp16)(sum0 + pA[0] * pB[0]); + sum1 = (__fp16)(sum1 + pA[1] * pB[0]); + pA += 2; + pB += 1; + } + + if (alpha != 1.f) + { + sum0 = (__fp16)(sum0 * _alpha); + sum1 = (__fp16)(sum1 * _alpha); + } + + if (k_end) + { + outptr0[0] = sum0; + outptr0[out_hstep] = sum1; + outptr0++; + } + else + { + outptr[0] = sum0; + outptr[1] = sum1; + } + + outptr += 2; + } + + pAT += max_kk * 2; + } + for (; ii < max_ii; ii += 1) + { + __fp16* outptr0 = (__fp16*)top_blob + (i + ii) * out_hstep + j; + + const __fp16* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + pC = (const __fp16*)CT_tile + i + ii; + if (broadcast_type_C == 4) + pC = (const __fp16*)CT_tile + j; + } + + int jj = 0; +#if __riscv_vector + for (; jj + 15 < max_jj; jj += 16) + { + vfloat16m2_t _sum; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum = __riscv_vfmv_v_f_f16m2(pC[0], vl16); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _sum = __riscv_vle16_v_f16m2(pC, vl16); + pC += 16; + } + } + else + { + _sum = __riscv_vfmv_v_f_f16m2((__fp16)0.f, vl16); + } + } + else + { + _sum = __riscv_vle16_v_f16m2(outptr, vl16); + } + + const __fp16* pA = pAT; + for (int kk = 0; kk < max_kk; kk++) + { + vfloat16m2_t _pB = __riscv_vle16_v_f16m2(pB, vl16); + _sum = __riscv_vfmacc_vf_f16m2(_sum, pA[0], _pB, vl16); + pA += 1; + pB += 16; + } + + if (alpha != 1.f) + { + _sum = __riscv_vfmul_vf_f16m2(_sum, _alpha, vl16); 
+ } + + if (k_end) + { + __riscv_vse16_v_f16m2(outptr0, _sum, vl16); + outptr0 += 16; + } + else + { + __riscv_vse16_v_f16m2(outptr, _sum, vl16); + } + + outptr += 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + vfloat16m1_t _sum; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum = __riscv_vfmv_v_f_f16m1(pC[0], vl8); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _sum = __riscv_vle16_v_f16m1(pC, vl8); + pC += 8; + } + } + else + { + _sum = __riscv_vfmv_v_f_f16m1((__fp16)0.f, vl8); + } + } + else + { + _sum = __riscv_vle16_v_f16m1(outptr, vl8); + } + + const __fp16* pA = pAT; + for (int kk = 0; kk < max_kk; kk++) + { + vfloat16m1_t _pB = __riscv_vle16_v_f16m1(pB, vl8); + _sum = __riscv_vfmacc_vf_f16m1(_sum, pA[0], _pB, vl8); + pA += 1; + pB += 8; + } + + if (alpha != 1.f) + { + _sum = __riscv_vfmul_vf_f16m1(_sum, _alpha, vl8); + } + + if (k_end) + { + __riscv_vse16_v_f16m1(outptr0, _sum, vl8); + outptr0 += 8; + } + else + { + __riscv_vse16_v_f16m1(outptr, _sum, vl8); + } + + outptr += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat16m1_t _sum; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum = __riscv_vfmv_v_f_f16m1(pC[0], vl4); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _sum = __riscv_vle16_v_f16m1(pC, vl4); + pC += 4; + } + } + else + { + _sum = __riscv_vfmv_v_f_f16m1((__fp16)0.f, vl4); + } + } + else + { + _sum = __riscv_vle16_v_f16m1(outptr, vl4); + } + + const __fp16* pA = pAT; + for (int kk = 0; kk < max_kk; kk++) + { + vfloat16m1_t _pB = __riscv_vle16_v_f16m1(pB, vl4); + _sum = __riscv_vfmacc_vf_f16m1(_sum, pA[0], _pB, vl4); + pA += 1; + pB += 4; + } + + if (alpha != 1.f) + { + _sum = __riscv_vfmul_vf_f16m1(_sum, _alpha, vl4); + } + + if (k_end) + { + __riscv_vse16_v_f16m1(outptr0, _sum, vl4); + outptr0 += 4; + } + else + { + __riscv_vse16_v_f16m1(outptr, _sum, vl4); 
+ } + + outptr += 4; + } +#endif // __riscv_vector + for (; jj + 1 < max_jj; jj += 2) + { + __fp16 sum0; + __fp16 sum1; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + sum0 = pC[0]; + sum1 = pC[0]; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + sum0 = pC[0]; + sum1 = pC[1]; + pC += 2; + } + } + else + { + sum0 = (__fp16)0.f; + sum1 = (__fp16)0.f; + } + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const __fp16* pA = pAT; + for (int kk = 0; kk < max_kk; kk++) + { + sum0 = (__fp16)(sum0 + pA[0] * pB[0]); + sum1 = (__fp16)(sum1 + pA[0] * pB[1]); + pA += 1; + pB += 2; + } + + if (alpha != 1.f) + { + sum0 = (__fp16)(sum0 * _alpha); + sum1 = (__fp16)(sum1 * _alpha); + } + + if (k_end) + { + outptr0[0] = sum0; + outptr0[1] = sum1; + outptr0 += 2; + } + else + { + outptr[0] = sum0; + outptr[1] = sum1; + } + + outptr += 2; + } + for (; jj < max_jj; jj += 1) + { + __fp16 sum; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + sum = pC[0]; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + sum = pC[0]; + pC += 1; + } + } + else + { + sum = (__fp16)0.f; + } + } + else + { + sum = outptr[0]; + } + + const __fp16* pA = pAT; + for (int kk = 0; kk < max_kk; kk++) + { + sum = (__fp16)(sum + pA[0] * pB[0]); + pA += 1; + pB += 1; + } + + if (alpha != 1.f) + { + sum = (__fp16)(sum * _alpha); + } + + if (k_end) + { + outptr0[0] = sum; + outptr0++; + } + else + { + outptr[0] = sum; + } + + outptr += 1; + } + + pAT += max_kk; + } +} + +static void get_optimal_tile_mnk_fp16sa(int M, int N, int K, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) +{ + // resolve optimal tile size from cache size + const size_t l2_cache_size = get_cpu_level2_cache_size(); + + if (nT == 0) + nT = get_physical_big_cpu_count(); + + int tile_size = 
(int)sqrtf((float)l2_cache_size / 3 / sizeof(__fp16)); + +#if __riscv_vector + const int packn = csrr_vlenb() / 2; + const int packn_n = 16; +#else + const int packn = 4; + const int packn_n = 4; +#endif + + TILE_M = std::max(packn, tile_size / packn * packn); + TILE_N = std::max(packn_n, tile_size / packn_n * packn_n); + TILE_K = std::max(packn, tile_size / packn * packn); + + if (K > 0) + { + int nn_K = (K + TILE_K - 1) / TILE_K; + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + (packn - 1)) / packn * packn); + + if (nn_K == 1) + { + tile_size = (int)((float)l2_cache_size / 2 / sizeof(__fp16) / TILE_K); + TILE_M = std::max(packn, tile_size / packn * packn); + TILE_N = std::max(packn_n, tile_size / packn_n * packn_n); + } + } + + TILE_M *= std::min(nT, get_physical_cpu_count()); + + if (M > 0) + { + int nn_M = (M + TILE_M - 1) / TILE_M; + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + (packn - 1)) / packn * packn); + } + + if (N > 0) + { + int nn_N = (N + TILE_N - 1) / TILE_N; + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + (packn_n - 1)) / packn_n * packn_n); + } + + if (nT > 1) + { + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + (packn - 1)) / packn * packn); + } + + // always take constant TILE_M/N/K value when provided + if (constant_TILE_M > 0) + { + TILE_M = (constant_TILE_M + (packn - 1)) / packn * packn; + } + + if (constant_TILE_N > 0) + { + TILE_N = (constant_TILE_N + (packn_n - 1)) / packn_n * packn_n; + } + + if (constant_TILE_K > 0) + { + TILE_K = (constant_TILE_K + (packn - 1)) / packn * packn; + } +} diff --git a/src/layer/riscv/gemm_riscv.cpp b/src/layer/riscv/gemm_riscv.cpp index b8578e704aa1..f3ae95d516e0 100644 --- a/src/layer/riscv/gemm_riscv.cpp +++ b/src/layer/riscv/gemm_riscv.cpp @@ -38,7 +38,6 @@ static void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max const size_t vl = __riscv_vsetvl_e32m1(packn); #endif - const int elempack = A.elempack; const size_t A_hstep = A.dims == 3 ? 
A.cstep : (size_t)A.w; // NCNN_LOGE("pack_A_tile %d", elempack); @@ -47,6 +46,8 @@ static void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max int ii = 0; #if __riscv_vector + const int elempack = A.elempack; + for (; ii + (packn - 1) < max_ii; ii += packn) { if (elempack == packn) @@ -243,6 +244,9 @@ static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max #if __riscv_vector const int packn = csrr_vlenb() / 4; const size_t vl = __riscv_vsetvl_e32m1(packn); + const size_t vl16 = __riscv_vsetvl_e32m4(16); + const size_t vl8 = __riscv_vsetvl_e32m2(8); + const size_t vl4 = __riscv_vsetvl_e32m1(4); #endif const int elempack = B.elempack; @@ -254,16 +258,123 @@ static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max int jj = 0; #if __riscv_vector - for (; jj + (packn - 1) < max_jj; jj += packn) + for (; jj + 15 < max_jj; jj += 16) { if (elempack == packn) { - const float* p0 = (const float*)B + (j + jj) * B_hstep + k * packn; + const int q = (j + jj) / packn * packn; + const int r = (j + jj) % packn; + const float* p0 = (const float*)B + q * B_hstep + k * packn + r; + + if (packn >= 16) + { + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse32_v_f32m4(pp, __riscv_vle32_v_f32m4(p0, vl16), vl16); + pp += 16; + p0 += packn; + } + } + if (packn == 8) + { + const float* p1 = (const float*)B + (q + 8) * B_hstep + k * 8; + + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); + __riscv_vse32_v_f32m1(pp + 8, __riscv_vle32_v_f32m1(p1, vl), vl); + pp += 16; + p0 += 8; + p1 += 8; + } + } + if (packn == 4) + { + const float* p1 = (const float*)B + (q + 4) * B_hstep + k * 4; + const float* p2 = (const float*)B + (q + 8) * B_hstep + k * 4; + const float* p3 = (const float*)B + (q + 12) * B_hstep + k * 4; + + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); + __riscv_vse32_v_f32m1(pp + 4, 
__riscv_vle32_v_f32m1(p1, vl), vl); + __riscv_vse32_v_f32m1(pp + 8, __riscv_vle32_v_f32m1(p2, vl), vl); + __riscv_vse32_v_f32m1(pp + 12, __riscv_vle32_v_f32m1(p3, vl), vl); + pp += 16; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + } + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; for (int kk = 0; kk < max_kk; kk++) { - __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - pp += packn; + __riscv_vse32_v_f32m4(pp, __riscv_vlse32_v_f32m4(p0, B_hstep * sizeof(float), vl16), vl16); + pp += 16; + p0++; + } + } + } + for (; jj + 7 < max_jj; jj += 8) + { + if (elempack == packn) + { + const int q = (j + jj) / packn * packn; + const int r = (j + jj) % packn; + const float* p0 = (const float*)B + q * B_hstep + k * packn + r; + + if (packn >= 8) + { + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse32_v_f32m2(pp, __riscv_vle32_v_f32m2(p0, vl8), vl8); + pp += 8; + p0 += packn; + } + } + if (packn == 4) + { + const float* p1 = (const float*)B + (q + 4) * B_hstep + k * 4; + + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); + __riscv_vse32_v_f32m1(pp + 4, __riscv_vle32_v_f32m1(p1, vl), vl); + pp += 8; + p0 += 4; + p1 += 4; + } + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse32_v_f32m2(pp, __riscv_vlse32_v_f32m2(p0, B_hstep * sizeof(float), vl8), vl8); + pp += 8; + p0++; + } + } + } + for (; jj + 3 < max_jj; jj += 4) + { + if (elempack == packn) + { + const int q = (j + jj) / packn * packn; + const int r = (j + jj) % packn; + const float* p0 = (const float*)B + q * B_hstep + k * packn + r; + + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl4), vl4); + pp += 4; p0 += packn; } } @@ -273,8 +384,8 @@ static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max for (int kk = 0; kk < max_kk; kk++) { - 
__riscv_vse32_v_f32m1(pp, __riscv_vlse32_v_f32m1(p0, B_hstep * sizeof(float), vl), vl); - pp += packn; + __riscv_vse32_v_f32m1(pp, __riscv_vlse32_v_f32m1(p0, B_hstep * sizeof(float), vl4), vl4); + pp += 4; p0++; } } @@ -282,7 +393,23 @@ static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max #endif // __riscv_vector for (; jj + 1 < max_jj; jj += 2) { - // if (elempack == 1) +#if __riscv_vector + if (elempack == packn) + { + const int q = (j + jj) / packn * packn; + const int r = (j + jj) % packn; + const float* p0 = (const float*)B + q * B_hstep + k * packn + r; + + for (int kk = 0; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += packn; + } + } +#endif // __riscv_vector + if (elempack == 1) { const float* p0 = (const float*)B + (j + jj) * B_hstep + k; const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; @@ -311,7 +438,22 @@ static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max } for (; jj < max_jj; jj += 1) { - // if (elempack == 1) +#if __riscv_vector + if (elempack == packn) + { + const int q = (j + jj) / packn * packn; + const int r = (j + jj) % packn; + const float* p0 = (const float*)B + q * B_hstep + k * packn + r; + + for (int kk = 0; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += packn; + } + } +#endif // __riscv_vector + if (elempack == 1) { const float* p0 = (const float*)B + (j + jj) * B_hstep + k; @@ -339,6 +481,9 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int #if __riscv_vector const int packn = csrr_vlenb() / 4; const size_t vl = __riscv_vsetvl_e32m1(packn); + const size_t vl16 = __riscv_vsetvl_e32m4(16); + const size_t vl8 = __riscv_vsetvl_e32m2(8); + const size_t vl4 = __riscv_vsetvl_e32m1(4); #endif const int elempack = B.elempack; @@ -350,7 +495,134 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int int jj = 0; #if __riscv_vector - for (; jj + (packn - 1) < max_jj; jj += packn) + 
for (; jj + 15 < max_jj; jj += 16) + { + if (elempack == packn) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * packn; + + if (packn >= 16) + { + int kk = 0; + for (; kk + (packn - 1) < max_kk; kk += packn) + { + // transposeNx16 + for (int l = 0; l < packn; l++) + { + __riscv_vse32_v_f32m4(pp, __riscv_vlse32_v_f32m4(p0 + l, packn * sizeof(float), vl16), vl16); + pp += 16; + } + p0 += B_hstep * packn; + } + } + if (packn == 8) + { + const float* p1 = p0 + 8 * 8; + + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + // transpose8x8 + transpose8x8 + for (int l = 0; l < 8; l++) + { + __riscv_vse32_v_f32m1(pp, __riscv_vlse32_v_f32m1(p0 + l, 8 * sizeof(float), vl), vl); + __riscv_vse32_v_f32m1(pp + 8, __riscv_vlse32_v_f32m1(p1 + l, 8 * sizeof(float), vl), vl); + pp += 16; + } + p0 += B_hstep * 8; + p1 += B_hstep * 8; + } + } + if (packn == 4) + { + const float* p1 = p0 + 4 * 4; + const float* p2 = p0 + 8 * 4; + const float* p3 = p0 + 12 * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + // transpose4x4 + transpose4x4 + transpose4x4 + transpose4x4 + for (int l = 0; l < 4; l++) + { + __riscv_vse32_v_f32m1(pp, __riscv_vlse32_v_f32m1(p0 + l, 4 * sizeof(float), vl), vl); + __riscv_vse32_v_f32m1(pp + 4, __riscv_vlse32_v_f32m1(p1 + l, 4 * sizeof(float), vl), vl); + __riscv_vse32_v_f32m1(pp + 8, __riscv_vlse32_v_f32m1(p2 + l, 4 * sizeof(float), vl), vl); + __riscv_vse32_v_f32m1(pp + 12, __riscv_vlse32_v_f32m1(p3 + l, 4 * sizeof(float), vl), vl); + pp += 16; + } + p0 += B_hstep * 4; + p1 += B_hstep * 4; + p2 += B_hstep * 4; + p3 += B_hstep * 4; + } + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + for (int kk = 0; kk < max_kk; kk++) + { + __riscv_vse32_v_f32m4(pp, __riscv_vle32_v_f32m4(p0, vl16), vl16); + pp += 16; + p0 += B_hstep; + } + } + } + for (; jj + 7 < max_jj; jj += 8) + { + if (elempack == packn) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * packn; + + if (packn >= 8) 
+ { + int kk = 0; + for (; kk + (packn - 1) < max_kk; kk += packn) + { + // transposeNx8 + for (int l = 0; l < packn; l++) + { + __riscv_vse32_v_f32m2(pp, __riscv_vlse32_v_f32m2(p0 + l, packn * sizeof(float), vl8), vl8); + pp += 8; + } + p0 += B_hstep * packn; + } + } + if (packn == 4) + { + const float* p1 = p0 + 4 * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + // transpose4x4 + transpose4x4 + for (int l = 0; l < 4; l++) + { + __riscv_vse32_v_f32m1(pp, __riscv_vlse32_v_f32m1(p0 + l, 4 * sizeof(float), vl), vl); + __riscv_vse32_v_f32m1(pp + 4, __riscv_vlse32_v_f32m1(p1 + l, 4 * sizeof(float), vl), vl); + pp += 8; + } + p0 += B_hstep * 4; + p1 += B_hstep * 4; + } + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + __riscv_vse32_v_f32m2(pp, __riscv_vle32_v_f32m2(p0, vl8), vl8); + pp += 8; + p0 += B_hstep; + } + } + } + for (; jj + 3 < max_jj; jj += 4) { if (elempack == packn) { @@ -359,11 +631,11 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int int kk = 0; for (; kk + (packn - 1) < max_kk; kk += packn) { - // transposeNxN + // transposeNx4 for (int l = 0; l < packn; l++) { - __riscv_vse32_v_f32m1(pp, __riscv_vlse32_v_f32m1(p0 + l, packn * sizeof(float), vl), vl); - pp += packn; + __riscv_vse32_v_f32m1(pp, __riscv_vlse32_v_f32m1(p0 + l, packn * sizeof(float), vl4), vl4); + pp += 4; } p0 += B_hstep * packn; } @@ -375,8 +647,8 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int int kk = 0; for (; kk < max_kk; kk++) { - __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - pp += packn; + __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl4), vl4); + pp += 4; p0 += B_hstep; } } @@ -450,9 +722,14 @@ static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i, #if __riscv_vector const int packn = csrr_vlenb() / 4; const size_t vl = __riscv_vsetvl_e32m1(packn); + const 
size_t vl16 = __riscv_vsetvl_e32m4(16); + const size_t vl8 = __riscv_vsetvl_e32m2(8); + const size_t vl4 = __riscv_vsetvl_e32m1(4); #endif +#if __riscv_vector const int out_elempack = top_blob.elempack; +#endif const size_t out_hstep = top_blob.dims == 3 ? top_blob.cstep : (size_t)top_blob.w; // NCNN_LOGE("transpose_unpack_output_tile %d", out_elempack); @@ -465,9 +742,25 @@ static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i, { if (out_elempack == packn) { - float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * packn; + int jj = 0; + + const int r0 = j % packn; + if (r0 != 0) + { + const int nn = std::min(packn - r0, max_jj); + float* p0 = (float*)top_blob + (j / packn * packn) * out_hstep + r0 + (i + ii) * packn; + + for (; jj < nn; jj++) + { + __riscv_vsse32_v_f32m1(p0, packn * sizeof(float), __riscv_vle32_v_f32m1(pp, vl), vl); + pp += packn; + p0++; + } + } + + float* p0 = (float*)top_blob + (j + jj) * out_hstep + (i + ii) * packn; - for (int jj = 0; jj + (packn - 1) < max_jj; jj += packn) + for (; jj + (packn - 1) < max_jj; jj += packn) { // transposeNxN for (int l = 0; l < packn; l++) @@ -475,8 +768,16 @@ static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i, __riscv_vsse32_v_f32m1(p0 + l, packn * sizeof(float), __riscv_vle32_v_f32m1(pp, vl), vl); pp += packn; } + p0 += out_hstep * packn; } + + for (; jj < max_jj; jj++) + { + __riscv_vsse32_v_f32m1(p0, packn * sizeof(float), __riscv_vle32_v_f32m1(pp, vl), vl); + pp += packn; + p0++; + } } if (out_elempack == 1) { @@ -484,8 +785,7 @@ static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i, for (int jj = 0; jj < max_jj; jj += 1) { - vfloat32m1_t _r0 = __riscv_vle32_v_f32m1(pp, vl); - __riscv_vse32_v_f32m1(p0, _r0, vl); + __riscv_vse32_v_f32m1(p0, __riscv_vle32_v_f32m1(pp, vl), vl); pp += packn; p0 += out_hstep; } @@ -497,23 +797,126 @@ static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i, #if __riscv_vector 
if (out_elempack == packn) { - float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * packn; + int jj = 0; - for (int jj = 0; jj + (packn - 1) < max_jj; jj += packn) + for (; jj + 15 < max_jj; jj += 16) { - vfloat32m1x2_t _s0 = __riscv_vlseg2e32_v_f32m1x2(pp, vl); - __riscv_vse32_v_f32m1(p0, __riscv_vget_v_f32m1x2_f32m1(_s0, 0), vl); - __riscv_vse32_v_f32m1(p0 + packn, __riscv_vget_v_f32m1x2_f32m1(_s0, 1), vl); - pp += packn * 2; - p0 += out_hstep * packn; + const int out_j = j + jj; + float* p0 = (float*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + if (packn == 4) + { + __riscv_vse32_v_f32m1(p0, __riscv_vle32_v_f32m1(pp, vl4), vl4); + __riscv_vse32_v_f32m1(p0 + packn, __riscv_vle32_v_f32m1(pp + 16, vl4), vl4); + p0 += out_hstep * 4; + __riscv_vse32_v_f32m1(p0, __riscv_vle32_v_f32m1(pp + 4, vl4), vl4); + __riscv_vse32_v_f32m1(p0 + packn, __riscv_vle32_v_f32m1(pp + 20, vl4), vl4); + p0 += out_hstep * 4; + __riscv_vse32_v_f32m1(p0, __riscv_vle32_v_f32m1(pp + 8, vl4), vl4); + __riscv_vse32_v_f32m1(p0 + packn, __riscv_vle32_v_f32m1(pp + 24, vl4), vl4); + p0 += out_hstep * 4; + __riscv_vse32_v_f32m1(p0, __riscv_vle32_v_f32m1(pp + 12, vl4), vl4); + __riscv_vse32_v_f32m1(p0 + packn, __riscv_vle32_v_f32m1(pp + 28, vl4), vl4); + } + if (packn == 8) + { + __riscv_vse32_v_f32m2(p0, __riscv_vle32_v_f32m2(pp, vl8), vl8); + __riscv_vse32_v_f32m2(p0 + packn, __riscv_vle32_v_f32m2(pp + 16, vl8), vl8); + p0 += out_hstep * 8; + __riscv_vse32_v_f32m2(p0, __riscv_vle32_v_f32m2(pp + 8, vl8), vl8); + __riscv_vse32_v_f32m2(p0 + packn, __riscv_vle32_v_f32m2(pp + 24, vl8), vl8); + } + if (packn >= 16) + { + __riscv_vse32_v_f32m4(p0, __riscv_vle32_v_f32m4(pp, vl16), vl16); + __riscv_vse32_v_f32m4(p0 + packn, __riscv_vle32_v_f32m4(pp + 16, vl16), vl16); + } + pp += 16 * 2; + } + for (; jj + 7 < max_jj; jj += 8) + { + const int out_j = j + jj; + float* p0 = (float*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + 
if (packn == 4) + { + __riscv_vse32_v_f32m1(p0, __riscv_vle32_v_f32m1(pp, vl4), vl4); + __riscv_vse32_v_f32m1(p0 + packn, __riscv_vle32_v_f32m1(pp + 8, vl4), vl4); + p0 += out_hstep * 4; + __riscv_vse32_v_f32m1(p0, __riscv_vle32_v_f32m1(pp + 4, vl4), vl4); + __riscv_vse32_v_f32m1(p0 + packn, __riscv_vle32_v_f32m1(pp + 12, vl4), vl4); + } + if (packn >= 8) + { + __riscv_vse32_v_f32m2(p0, __riscv_vle32_v_f32m2(pp, vl8), vl8); + __riscv_vse32_v_f32m2(p0 + packn, __riscv_vle32_v_f32m2(pp + 8, vl8), vl8); + } + pp += 8 * 2; + } + for (; jj + 3 < max_jj; jj += 4) + { + const int out_j = j + jj; + float* p0 = (float*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + __riscv_vse32_v_f32m1(p0, __riscv_vle32_v_f32m1(pp, vl4), vl4); + __riscv_vse32_v_f32m1(p0 + packn, __riscv_vle32_v_f32m1(pp + 4, vl4), vl4); + pp += 4 * 2; + } + for (; jj + 1 < max_jj; jj += 2) + { + const int out_j = j + jj; + float* p0 = (float*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + p0[0] = pp[0]; + p0[1] = pp[1]; + p0[packn] = pp[2]; + p0[packn + 1] = pp[3]; + pp += 2 * 2; + } + for (; jj < max_jj; jj += 1) + { + const int out_j = j + jj; + float* p0 = (float*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + p0[0] = pp[0]; + p0[packn] = pp[1]; + pp += 2; } } -#endif // __riscv_vector if (out_elempack == 1) +#endif // __riscv_vector { float* p0 = (float*)top_blob + j * out_hstep + (i + ii); - for (int jj = 0; jj < max_jj; jj += 1) + int jj = 0; +#if __riscv_vector + for (; jj + 15 < max_jj; jj += 16) + { + __riscv_vsse32_v_f32m4(p0, out_hstep * sizeof(float), __riscv_vle32_v_f32m4(pp, vl16), vl16); + __riscv_vsse32_v_f32m4(p0 + 1, out_hstep * sizeof(float), __riscv_vle32_v_f32m4(pp + 16, vl16), vl16); + pp += 16 * 2; + p0 += out_hstep * 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + __riscv_vsse32_v_f32m2(p0, out_hstep * sizeof(float), __riscv_vle32_v_f32m2(pp, vl8), vl8); + 
__riscv_vsse32_v_f32m2(p0 + 1, out_hstep * sizeof(float), __riscv_vle32_v_f32m2(pp + 8, vl8), vl8); + pp += 8 * 2; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + __riscv_vsse32_v_f32m1(p0, out_hstep * sizeof(float), __riscv_vle32_v_f32m1(pp, vl4), vl4); + __riscv_vsse32_v_f32m1(p0 + 1, out_hstep * sizeof(float), __riscv_vle32_v_f32m1(pp + 4, vl4), vl4); + pp += 4 * 2; + p0 += out_hstep * 4; + } +#endif // __riscv_vector + for (; jj + 1 < max_jj; jj += 2) + { + p0[0] = pp[0]; + p0[out_hstep] = pp[1]; + p0[1] = pp[2]; + p0[out_hstep + 1] = pp[3]; + pp += 2 * 2; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) { p0[0] = pp[0]; p0[1] = pp[1]; @@ -527,22 +930,92 @@ static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i, #if __riscv_vector if (out_elempack == packn) { - float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * packn; + int jj = 0; - for (int jj = 0; jj + (packn - 1) < max_jj; jj += packn) + for (; jj + 15 < max_jj; jj += 16) { - vfloat32m1_t _r0 = __riscv_vle32_v_f32m1(pp, vl); - __riscv_vse32_v_f32m1(p0, _r0, vl); - pp += packn; - p0 += out_hstep * packn; + const int out_j = j + jj; + float* p0 = (float*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + if (packn == 4) + { + __riscv_vse32_v_f32m1(p0, __riscv_vle32_v_f32m1(pp, vl4), vl4); + p0 += out_hstep * 4; + __riscv_vse32_v_f32m1(p0, __riscv_vle32_v_f32m1(pp + 4, vl4), vl4); + p0 += out_hstep * 4; + __riscv_vse32_v_f32m1(p0, __riscv_vle32_v_f32m1(pp + 8, vl4), vl4); + p0 += out_hstep * 4; + __riscv_vse32_v_f32m1(p0, __riscv_vle32_v_f32m1(pp + 12, vl4), vl4); + } + if (packn == 8) + { + __riscv_vse32_v_f32m2(p0, __riscv_vle32_v_f32m2(pp, vl8), vl8); + p0 += out_hstep * 8; + __riscv_vse32_v_f32m2(p0, __riscv_vle32_v_f32m2(pp + 8, vl8), vl8); + } + if (packn >= 16) + { + __riscv_vse32_v_f32m4(p0, __riscv_vle32_v_f32m4(pp, vl16), vl16); + } + pp += 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + const int out_j = j 
+ jj; + float* p0 = (float*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + if (packn == 4) + { + __riscv_vse32_v_f32m1(p0, __riscv_vle32_v_f32m1(pp, vl4), vl4); + p0 += out_hstep * 4; + __riscv_vse32_v_f32m1(p0, __riscv_vle32_v_f32m1(pp + 4, vl4), vl4); + } + if (packn >= 8) + { + __riscv_vse32_v_f32m2(p0, __riscv_vle32_v_f32m2(pp, vl8), vl8); + } + pp += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + const int out_j = j + jj; + float* p0 = (float*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + __riscv_vse32_v_f32m1(p0, __riscv_vle32_v_f32m1(pp, vl4), vl4); + pp += 4; + } + for (; jj < max_jj; jj += 1) + { + const int out_j = j + jj; + float* p0 = (float*)top_blob + (out_j / packn * packn) * out_hstep + out_j % packn + (i + ii) * packn; + p0[0] = pp[0]; + pp += 1; } } -#endif // __riscv_vector if (out_elempack == 1) +#endif // __riscv_vector { float* p0 = (float*)top_blob + j * out_hstep + (i + ii); - for (int jj = 0; jj < max_jj; jj += 1) + int jj = 0; +#if __riscv_vector + for (; jj + 15 < max_jj; jj += 16) + { + __riscv_vsse32_v_f32m4(p0, out_hstep * sizeof(float), __riscv_vle32_v_f32m4(pp, vl16), vl16); + pp += 16; + p0 += out_hstep * 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + __riscv_vsse32_v_f32m2(p0, out_hstep * sizeof(float), __riscv_vle32_v_f32m2(pp, vl8), vl8); + pp += 8; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + __riscv_vsse32_v_f32m1(p0, out_hstep * sizeof(float), __riscv_vle32_v_f32m1(pp, vl4), vl4); + pp += 4; + p0 += out_hstep * 4; + } +#endif // __riscv_vector + for (; jj < max_jj; jj += 1) { p0[0] = pp[0]; pp += 1; @@ -557,9 +1030,14 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons #if __riscv_vector const int packn = csrr_vlenb() / 4; const size_t vl = __riscv_vsetvl_e32m1(packn); + const size_t vl16 = __riscv_vsetvl_e32m4(16); + const size_t vl8 = __riscv_vsetvl_e32m2(8); + const size_t vl4 = 
__riscv_vsetvl_e32m1(4); #endif +#if __riscv_vector const int out_elempack = top_blob.elempack; +#endif const size_t out_hstep = top_blob.dims == 3 ? top_blob.cstep : (size_t)top_blob.w; const float* pAT = AT_tile; @@ -589,20 +1067,109 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } int jj = 0; - for (; jj + (packn - 1) < max_jj; jj += packn) + for (; jj + 15 < max_jj; jj += 16) { - if (packn == 8) - { - vfloat32m1_t _sum0; - vfloat32m1_t _sum1; - vfloat32m1_t _sum2; - vfloat32m1_t _sum3; - vfloat32m1_t _sum4; - vfloat32m1_t _sum5; - vfloat32m1_t _sum6; - vfloat32m1_t _sum7; - - if (k == 0) + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + vfloat32m1_t _sum2; + vfloat32m1_t _sum3; + vfloat32m1_t _sum4; + vfloat32m1_t _sum5; + vfloat32m1_t _sum6; + vfloat32m1_t _sum7; + vfloat32m1_t _sum8; + vfloat32m1_t _sum9; + vfloat32m1_t _suma; + vfloat32m1_t _sumb; + vfloat32m1_t _sumc; + vfloat32m1_t _sumd; + vfloat32m1_t _sume; + vfloat32m1_t _sumf; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + _suma = _sum0; + _sumb = _sum0; + _sumc = _sum0; + _sumd = _sum0; + _sume = _sum0; + _sumf = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = __riscv_vle32_v_f32m1(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + _suma = _sum0; + _sumb = _sum0; + _sumc = _sum0; + _sumd = _sum0; + _sume = _sum0; + _sumf = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = __riscv_vle32_v_f32m1(pC, vl); + _sum1 = __riscv_vle32_v_f32m1(pC + packn, vl); + _sum2 = __riscv_vle32_v_f32m1(pC + packn * 2, vl); + _sum3 = __riscv_vle32_v_f32m1(pC + packn * 3, vl); + _sum4 = __riscv_vle32_v_f32m1(pC + packn * 4, vl); 
+ _sum5 = __riscv_vle32_v_f32m1(pC + packn * 5, vl); + _sum6 = __riscv_vle32_v_f32m1(pC + packn * 6, vl); + _sum7 = __riscv_vle32_v_f32m1(pC + packn * 7, vl); + _sum8 = __riscv_vle32_v_f32m1(pC + packn * 8, vl); + _sum9 = __riscv_vle32_v_f32m1(pC + packn * 9, vl); + _suma = __riscv_vle32_v_f32m1(pC + packn * 10, vl); + _sumb = __riscv_vle32_v_f32m1(pC + packn * 11, vl); + _sumc = __riscv_vle32_v_f32m1(pC + packn * 12, vl); + _sumd = __riscv_vle32_v_f32m1(pC + packn * 13, vl); + _sume = __riscv_vle32_v_f32m1(pC + packn * 14, vl); + _sumf = __riscv_vle32_v_f32m1(pC + packn * 15, vl); + pC += packn * 16; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); + _sum1 = __riscv_vfmv_v_f_f32m1(pC[1], vl); + _sum2 = __riscv_vfmv_v_f_f32m1(pC[2], vl); + _sum3 = __riscv_vfmv_v_f_f32m1(pC[3], vl); + _sum4 = __riscv_vfmv_v_f_f32m1(pC[4], vl); + _sum5 = __riscv_vfmv_v_f_f32m1(pC[5], vl); + _sum6 = __riscv_vfmv_v_f_f32m1(pC[6], vl); + _sum7 = __riscv_vfmv_v_f_f32m1(pC[7], vl); + _sum8 = __riscv_vfmv_v_f_f32m1(pC[8], vl); + _sum9 = __riscv_vfmv_v_f_f32m1(pC[9], vl); + _suma = __riscv_vfmv_v_f_f32m1(pC[10], vl); + _sumb = __riscv_vfmv_v_f_f32m1(pC[11], vl); + _sumc = __riscv_vfmv_v_f_f32m1(pC[12], vl); + _sumd = __riscv_vfmv_v_f_f32m1(pC[13], vl); + _sume = __riscv_vfmv_v_f_f32m1(pC[14], vl); + _sumf = __riscv_vfmv_v_f_f32m1(pC[15], vl); + pC += 16; + } + } + else { _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl); _sum1 = __riscv_vfmv_v_f_f32m1(0.f, vl); @@ -612,221 +1179,361 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons _sum5 = __riscv_vfmv_v_f_f32m1(0.f, vl); _sum6 = __riscv_vfmv_v_f_f32m1(0.f, vl); _sum7 = __riscv_vfmv_v_f_f32m1(0.f, vl); - - if (pC) - { - if (broadcast_type_C == 0) - { - _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum1 = _sum0; - _sum2 = _sum0; - _sum3 = _sum0; - _sum4 = _sum0; - _sum5 = _sum0; - _sum6 = _sum0; - _sum7 = _sum0; - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum0 = 
__riscv_vle32_v_f32m1(pC, vl); - _sum1 = _sum0; - _sum2 = _sum0; - _sum3 = _sum0; - _sum4 = _sum0; - _sum5 = _sum0; - _sum6 = _sum0; - _sum7 = _sum0; - } - if (broadcast_type_C == 3) - { - _sum0 = __riscv_vle32_v_f32m1(pC, vl); - _sum1 = __riscv_vle32_v_f32m1(pC + packn, vl); - _sum2 = __riscv_vle32_v_f32m1(pC + packn * 2, vl); - _sum3 = __riscv_vle32_v_f32m1(pC + packn * 3, vl); - _sum4 = __riscv_vle32_v_f32m1(pC + packn * 4, vl); - _sum5 = __riscv_vle32_v_f32m1(pC + packn * 5, vl); - _sum6 = __riscv_vle32_v_f32m1(pC + packn * 6, vl); - _sum7 = __riscv_vle32_v_f32m1(pC + packn * 7, vl); - pC += packn * 8; - } - if (broadcast_type_C == 4) - { - _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum1 = __riscv_vfmv_v_f_f32m1(pC[1], vl); - _sum2 = __riscv_vfmv_v_f_f32m1(pC[2], vl); - _sum3 = __riscv_vfmv_v_f_f32m1(pC[3], vl); - _sum4 = __riscv_vfmv_v_f_f32m1(pC[4], vl); - _sum5 = __riscv_vfmv_v_f_f32m1(pC[5], vl); - _sum6 = __riscv_vfmv_v_f_f32m1(pC[6], vl); - _sum7 = __riscv_vfmv_v_f_f32m1(pC[7], vl); - pC += 8; - } - } + _sum8 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum9 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _suma = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sumb = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sumc = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sumd = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sume = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sumf = __riscv_vfmv_v_f_f32m1(0.f, vl); } - else + } + else + { + _sum0 = __riscv_vle32_v_f32m1(outptr, vl); + _sum1 = __riscv_vle32_v_f32m1(outptr + packn, vl); + _sum2 = __riscv_vle32_v_f32m1(outptr + packn * 2, vl); + _sum3 = __riscv_vle32_v_f32m1(outptr + packn * 3, vl); + _sum4 = __riscv_vle32_v_f32m1(outptr + packn * 4, vl); + _sum5 = __riscv_vle32_v_f32m1(outptr + packn * 5, vl); + _sum6 = __riscv_vle32_v_f32m1(outptr + packn * 6, vl); + _sum7 = __riscv_vle32_v_f32m1(outptr + packn * 7, vl); + _sum8 = __riscv_vle32_v_f32m1(outptr + packn * 8, vl); + _sum9 = __riscv_vle32_v_f32m1(outptr + packn * 9, vl); + _suma = __riscv_vle32_v_f32m1(outptr + packn 
* 10, vl); + _sumb = __riscv_vle32_v_f32m1(outptr + packn * 11, vl); + _sumc = __riscv_vle32_v_f32m1(outptr + packn * 12, vl); + _sumd = __riscv_vle32_v_f32m1(outptr + packn * 13, vl); + _sume = __riscv_vle32_v_f32m1(outptr + packn * 14, vl); + _sumf = __riscv_vle32_v_f32m1(outptr + packn * 15, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA = __riscv_vle32_v_f32m1(pA, vl); + _sum0 = __riscv_vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); + _sum1 = __riscv_vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); + _sum2 = __riscv_vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); + _sum3 = __riscv_vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); + _sum4 = __riscv_vfmadd_vf_f32m1(_pA, pB[4], _sum4, vl); + _sum5 = __riscv_vfmadd_vf_f32m1(_pA, pB[5], _sum5, vl); + _sum6 = __riscv_vfmadd_vf_f32m1(_pA, pB[6], _sum6, vl); + _sum7 = __riscv_vfmadd_vf_f32m1(_pA, pB[7], _sum7, vl); + _sum8 = __riscv_vfmadd_vf_f32m1(_pA, pB[8], _sum8, vl); + _sum9 = __riscv_vfmadd_vf_f32m1(_pA, pB[9], _sum9, vl); + _suma = __riscv_vfmadd_vf_f32m1(_pA, pB[10], _suma, vl); + _sumb = __riscv_vfmadd_vf_f32m1(_pA, pB[11], _sumb, vl); + _sumc = __riscv_vfmadd_vf_f32m1(_pA, pB[12], _sumc, vl); + _sumd = __riscv_vfmadd_vf_f32m1(_pA, pB[13], _sumd, vl); + _sume = __riscv_vfmadd_vf_f32m1(_pA, pB[14], _sume, vl); + _sumf = __riscv_vfmadd_vf_f32m1(_pA, pB[15], _sumf, vl); + pA += packn; + pB += 16; + } + + if (k_end) + { + if (out_elempack == packn) { - _sum0 = __riscv_vle32_v_f32m1(outptr, vl); - _sum1 = __riscv_vle32_v_f32m1(outptr + packn, vl); - _sum2 = __riscv_vle32_v_f32m1(outptr + packn * 2, vl); - _sum3 = __riscv_vle32_v_f32m1(outptr + packn * 3, vl); - _sum4 = __riscv_vle32_v_f32m1(outptr + packn * 4, vl); - _sum5 = __riscv_vle32_v_f32m1(outptr + packn * 5, vl); - _sum6 = __riscv_vle32_v_f32m1(outptr + packn * 6, vl); - _sum7 = __riscv_vle32_v_f32m1(outptr + packn * 7, vl); + __riscv_vse32_v_f32m1(outptr0, _sum0, vl); + __riscv_vse32_v_f32m1(outptr0 + packn, _sum1, vl); + 
__riscv_vse32_v_f32m1(outptr0 + packn * 2, _sum2, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 3, _sum3, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 4, _sum4, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 5, _sum5, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 6, _sum6, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 7, _sum7, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 8, _sum8, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 9, _sum9, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 10, _suma, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 11, _sumb, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 12, _sumc, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 13, _sumd, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 14, _sume, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 15, _sumf, vl); + outptr0 += packn * 16; } - - const float* pA = pAT; - int kk = 0; - for (; kk < max_kk; kk += 1) + if (out_elempack == 1) { - vfloat32m1_t _pA = __riscv_vle32_v_f32m1(pA, vl); - _sum0 = __riscv_vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); - _sum1 = __riscv_vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); - _sum2 = __riscv_vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); - _sum3 = __riscv_vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); - _sum4 = __riscv_vfmadd_vf_f32m1(_pA, pB[4], _sum4, vl); - _sum5 = __riscv_vfmadd_vf_f32m1(_pA, pB[5], _sum5, vl); - _sum6 = __riscv_vfmadd_vf_f32m1(_pA, pB[6], _sum6, vl); - _sum7 = __riscv_vfmadd_vf_f32m1(_pA, pB[7], _sum7, vl); - pA += packn; - pB += 8; + __riscv_vsse32_v_f32m1(outptr0, out_hstep * sizeof(float), _sum0, vl); + __riscv_vsse32_v_f32m1(outptr0 + 1, out_hstep * sizeof(float), _sum1, vl); + __riscv_vsse32_v_f32m1(outptr0 + 2, out_hstep * sizeof(float), _sum2, vl); + __riscv_vsse32_v_f32m1(outptr0 + 3, out_hstep * sizeof(float), _sum3, vl); + __riscv_vsse32_v_f32m1(outptr0 + 4, out_hstep * sizeof(float), _sum4, vl); + __riscv_vsse32_v_f32m1(outptr0 + 5, out_hstep * sizeof(float), _sum5, vl); + __riscv_vsse32_v_f32m1(outptr0 + 6, out_hstep * sizeof(float), _sum6, 
vl); + __riscv_vsse32_v_f32m1(outptr0 + 7, out_hstep * sizeof(float), _sum7, vl); + __riscv_vsse32_v_f32m1(outptr0 + 8, out_hstep * sizeof(float), _sum8, vl); + __riscv_vsse32_v_f32m1(outptr0 + 9, out_hstep * sizeof(float), _sum9, vl); + __riscv_vsse32_v_f32m1(outptr0 + 10, out_hstep * sizeof(float), _suma, vl); + __riscv_vsse32_v_f32m1(outptr0 + 11, out_hstep * sizeof(float), _sumb, vl); + __riscv_vsse32_v_f32m1(outptr0 + 12, out_hstep * sizeof(float), _sumc, vl); + __riscv_vsse32_v_f32m1(outptr0 + 13, out_hstep * sizeof(float), _sumd, vl); + __riscv_vsse32_v_f32m1(outptr0 + 14, out_hstep * sizeof(float), _sume, vl); + __riscv_vsse32_v_f32m1(outptr0 + 15, out_hstep * sizeof(float), _sumf, vl); + outptr0 += 16; } + } + else + { + __riscv_vse32_v_f32m1(outptr, _sum0, vl); + __riscv_vse32_v_f32m1(outptr + packn, _sum1, vl); + __riscv_vse32_v_f32m1(outptr + packn * 2, _sum2, vl); + __riscv_vse32_v_f32m1(outptr + packn * 3, _sum3, vl); + __riscv_vse32_v_f32m1(outptr + packn * 4, _sum4, vl); + __riscv_vse32_v_f32m1(outptr + packn * 5, _sum5, vl); + __riscv_vse32_v_f32m1(outptr + packn * 6, _sum6, vl); + __riscv_vse32_v_f32m1(outptr + packn * 7, _sum7, vl); + __riscv_vse32_v_f32m1(outptr + packn * 8, _sum8, vl); + __riscv_vse32_v_f32m1(outptr + packn * 9, _sum9, vl); + __riscv_vse32_v_f32m1(outptr + packn * 10, _suma, vl); + __riscv_vse32_v_f32m1(outptr + packn * 11, _sumb, vl); + __riscv_vse32_v_f32m1(outptr + packn * 12, _sumc, vl); + __riscv_vse32_v_f32m1(outptr + packn * 13, _sumd, vl); + __riscv_vse32_v_f32m1(outptr + packn * 14, _sume, vl); + __riscv_vse32_v_f32m1(outptr + packn * 15, _sumf, vl); + } + + outptr += packn * 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + vfloat32m1_t _sum2; + vfloat32m1_t _sum3; + vfloat32m1_t _sum4; + vfloat32m1_t _sum5; + vfloat32m1_t _sum6; + vfloat32m1_t _sum7; - if (k_end) + if (k == 0) + { + if (pC) { - if (out_elempack == packn) + if (broadcast_type_C == 0) + { + _sum0 = 
__riscv_vfmv_v_f_f32m1(pC[0], vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - __riscv_vse32_v_f32m1(outptr0, _sum0, vl); - __riscv_vse32_v_f32m1(outptr0 + packn, _sum1, vl); - __riscv_vse32_v_f32m1(outptr0 + packn * 2, _sum2, vl); - __riscv_vse32_v_f32m1(outptr0 + packn * 3, _sum3, vl); - __riscv_vse32_v_f32m1(outptr0 + packn * 4, _sum4, vl); - __riscv_vse32_v_f32m1(outptr0 + packn * 5, _sum5, vl); - __riscv_vse32_v_f32m1(outptr0 + packn * 6, _sum6, vl); - __riscv_vse32_v_f32m1(outptr0 + packn * 7, _sum7, vl); - outptr0 += packn * 8; + _sum0 = __riscv_vle32_v_f32m1(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = __riscv_vle32_v_f32m1(pC, vl); + _sum1 = __riscv_vle32_v_f32m1(pC + packn, vl); + _sum2 = __riscv_vle32_v_f32m1(pC + packn * 2, vl); + _sum3 = __riscv_vle32_v_f32m1(pC + packn * 3, vl); + _sum4 = __riscv_vle32_v_f32m1(pC + packn * 4, vl); + _sum5 = __riscv_vle32_v_f32m1(pC + packn * 5, vl); + _sum6 = __riscv_vle32_v_f32m1(pC + packn * 6, vl); + _sum7 = __riscv_vle32_v_f32m1(pC + packn * 7, vl); + pC += packn * 8; } - if (out_elempack == 1) + if (broadcast_type_C == 4) { - vfloat32m1x8_t _sum = __riscv_vcreate_v_f32m1x8(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7); - __riscv_vssseg8e32_v_f32m1x8(outptr0, out_hstep * sizeof(float), _sum, vl); - outptr0 += 8; + _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); + _sum1 = __riscv_vfmv_v_f_f32m1(pC[1], vl); + _sum2 = __riscv_vfmv_v_f_f32m1(pC[2], vl); + _sum3 = __riscv_vfmv_v_f_f32m1(pC[3], vl); + _sum4 = __riscv_vfmv_v_f_f32m1(pC[4], vl); + _sum5 = __riscv_vfmv_v_f_f32m1(pC[5], vl); + _sum6 = __riscv_vfmv_v_f_f32m1(pC[6], vl); + _sum7 = __riscv_vfmv_v_f_f32m1(pC[7], vl); + pC += 8; } } else - { - __riscv_vse32_v_f32m1(outptr, _sum0, vl); - 
__riscv_vse32_v_f32m1(outptr + packn, _sum1, vl); - __riscv_vse32_v_f32m1(outptr + packn * 2, _sum2, vl); - __riscv_vse32_v_f32m1(outptr + packn * 3, _sum3, vl); - __riscv_vse32_v_f32m1(outptr + packn * 4, _sum4, vl); - __riscv_vse32_v_f32m1(outptr + packn * 5, _sum5, vl); - __riscv_vse32_v_f32m1(outptr + packn * 6, _sum6, vl); - __riscv_vse32_v_f32m1(outptr + packn * 7, _sum7, vl); - } - - outptr += packn * 8; - } - else if (packn == 4) - { - vfloat32m1_t _sum0; - vfloat32m1_t _sum1; - vfloat32m1_t _sum2; - vfloat32m1_t _sum3; - - if (k == 0) { _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl); _sum1 = __riscv_vfmv_v_f_f32m1(0.f, vl); _sum2 = __riscv_vfmv_v_f_f32m1(0.f, vl); _sum3 = __riscv_vfmv_v_f_f32m1(0.f, vl); - - if (pC) - { - if (broadcast_type_C == 0) - { - _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum1 = _sum0; - _sum2 = _sum0; - _sum3 = _sum0; - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum0 = __riscv_vle32_v_f32m1(pC, vl); - _sum1 = _sum0; - _sum2 = _sum0; - _sum3 = _sum0; - } - if (broadcast_type_C == 3) - { - _sum0 = __riscv_vle32_v_f32m1(pC, vl); - _sum1 = __riscv_vle32_v_f32m1(pC + packn, vl); - _sum2 = __riscv_vle32_v_f32m1(pC + packn * 2, vl); - _sum3 = __riscv_vle32_v_f32m1(pC + packn * 3, vl); - pC += packn * 4; - } - if (broadcast_type_C == 4) - { - _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum1 = __riscv_vfmv_v_f_f32m1(pC[1], vl); - _sum2 = __riscv_vfmv_v_f_f32m1(pC[2], vl); - _sum3 = __riscv_vfmv_v_f_f32m1(pC[3], vl); - pC += 4; - } - } + _sum4 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum5 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum6 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum7 = __riscv_vfmv_v_f_f32m1(0.f, vl); } - else + } + else + { + _sum0 = __riscv_vle32_v_f32m1(outptr, vl); + _sum1 = __riscv_vle32_v_f32m1(outptr + packn, vl); + _sum2 = __riscv_vle32_v_f32m1(outptr + packn * 2, vl); + _sum3 = __riscv_vle32_v_f32m1(outptr + packn * 3, vl); + _sum4 = __riscv_vle32_v_f32m1(outptr + packn * 4, vl); + _sum5 = 
__riscv_vle32_v_f32m1(outptr + packn * 5, vl); + _sum6 = __riscv_vle32_v_f32m1(outptr + packn * 6, vl); + _sum7 = __riscv_vle32_v_f32m1(outptr + packn * 7, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA = __riscv_vle32_v_f32m1(pA, vl); + _sum0 = __riscv_vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); + _sum1 = __riscv_vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); + _sum2 = __riscv_vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); + _sum3 = __riscv_vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); + _sum4 = __riscv_vfmadd_vf_f32m1(_pA, pB[4], _sum4, vl); + _sum5 = __riscv_vfmadd_vf_f32m1(_pA, pB[5], _sum5, vl); + _sum6 = __riscv_vfmadd_vf_f32m1(_pA, pB[6], _sum6, vl); + _sum7 = __riscv_vfmadd_vf_f32m1(_pA, pB[7], _sum7, vl); + pA += packn; + pB += 8; + } + + if (k_end) + { + if (out_elempack == packn) { - _sum0 = __riscv_vle32_v_f32m1(outptr, vl); - _sum1 = __riscv_vle32_v_f32m1(outptr + packn, vl); - _sum2 = __riscv_vle32_v_f32m1(outptr + packn * 2, vl); - _sum3 = __riscv_vle32_v_f32m1(outptr + packn * 3, vl); + __riscv_vse32_v_f32m1(outptr0, _sum0, vl); + __riscv_vse32_v_f32m1(outptr0 + packn, _sum1, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 2, _sum2, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 3, _sum3, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 4, _sum4, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 5, _sum5, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 6, _sum6, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 7, _sum7, vl); + outptr0 += packn * 8; } - - const float* pA = pAT; - int kk = 0; - for (; kk < max_kk; kk += 1) + if (out_elempack == 1) { - vfloat32m1_t _pA = __riscv_vle32_v_f32m1(pA, vl); - _sum0 = __riscv_vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); - _sum1 = __riscv_vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); - _sum2 = __riscv_vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); - _sum3 = __riscv_vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); - pA += packn; - pB += 4; + vfloat32m1x8_t _sum = __riscv_vcreate_v_f32m1x8(_sum0, _sum1, _sum2, 
_sum3, _sum4, _sum5, _sum6, _sum7); + __riscv_vssseg8e32_v_f32m1x8(outptr0, out_hstep * sizeof(float), _sum, vl); + outptr0 += 8; } + } + else + { + __riscv_vse32_v_f32m1(outptr, _sum0, vl); + __riscv_vse32_v_f32m1(outptr + packn, _sum1, vl); + __riscv_vse32_v_f32m1(outptr + packn * 2, _sum2, vl); + __riscv_vse32_v_f32m1(outptr + packn * 3, _sum3, vl); + __riscv_vse32_v_f32m1(outptr + packn * 4, _sum4, vl); + __riscv_vse32_v_f32m1(outptr + packn * 5, _sum5, vl); + __riscv_vse32_v_f32m1(outptr + packn * 6, _sum6, vl); + __riscv_vse32_v_f32m1(outptr + packn * 7, _sum7, vl); + } - if (k_end) + outptr += packn * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + vfloat32m1_t _sum2; + vfloat32m1_t _sum3; + + if (k == 0) + { + if (pC) { - if (out_elempack == packn) + if (broadcast_type_C == 0) + { + _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - __riscv_vse32_v_f32m1(outptr0, _sum0, vl); - __riscv_vse32_v_f32m1(outptr0 + packn, _sum1, vl); - __riscv_vse32_v_f32m1(outptr0 + packn * 2, _sum2, vl); - __riscv_vse32_v_f32m1(outptr0 + packn * 3, _sum3, vl); - outptr0 += packn * 4; + _sum0 = __riscv_vle32_v_f32m1(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; } - if (out_elempack == 1) + if (broadcast_type_C == 3) { - vfloat32m1x4_t _sum = __riscv_vcreate_v_f32m1x4(_sum0, _sum1, _sum2, _sum3); - __riscv_vssseg4e32_v_f32m1x4(outptr0, out_hstep * sizeof(float), _sum, vl); - outptr0 += 4; + _sum0 = __riscv_vle32_v_f32m1(pC, vl); + _sum1 = __riscv_vle32_v_f32m1(pC + packn, vl); + _sum2 = __riscv_vle32_v_f32m1(pC + packn * 2, vl); + _sum3 = __riscv_vle32_v_f32m1(pC + packn * 3, vl); + pC += packn * 4; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); + _sum1 = __riscv_vfmv_v_f_f32m1(pC[1], vl); + _sum2 = __riscv_vfmv_v_f_f32m1(pC[2], vl); + _sum3 = __riscv_vfmv_v_f_f32m1(pC[3], vl); + pC += 
4; } } else { - __riscv_vse32_v_f32m1(outptr, _sum0, vl); - __riscv_vse32_v_f32m1(outptr + packn, _sum1, vl); - __riscv_vse32_v_f32m1(outptr + packn * 2, _sum2, vl); - __riscv_vse32_v_f32m1(outptr + packn * 3, _sum3, vl); + _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum1 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum2 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum3 = __riscv_vfmv_v_f_f32m1(0.f, vl); } + } + else + { + _sum0 = __riscv_vle32_v_f32m1(outptr, vl); + _sum1 = __riscv_vle32_v_f32m1(outptr + packn, vl); + _sum2 = __riscv_vle32_v_f32m1(outptr + packn * 2, vl); + _sum3 = __riscv_vle32_v_f32m1(outptr + packn * 3, vl); + } - outptr += packn * 4; + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA = __riscv_vle32_v_f32m1(pA, vl); + _sum0 = __riscv_vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); + _sum1 = __riscv_vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); + _sum2 = __riscv_vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); + _sum3 = __riscv_vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); + pA += packn; + pB += 4; + } + + if (k_end) + { + if (out_elempack == packn) + { + __riscv_vse32_v_f32m1(outptr0, _sum0, vl); + __riscv_vse32_v_f32m1(outptr0 + packn, _sum1, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 2, _sum2, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 3, _sum3, vl); + outptr0 += packn * 4; + } + if (out_elempack == 1) + { + vfloat32m1x4_t _sum = __riscv_vcreate_v_f32m1x4(_sum0, _sum1, _sum2, _sum3); + __riscv_vssseg4e32_v_f32m1x4(outptr0, out_hstep * sizeof(float), _sum, vl); + outptr0 += 4; + } } else { - NCNN_LOGE("unsupported vector length"); + __riscv_vse32_v_f32m1(outptr, _sum0, vl); + __riscv_vse32_v_f32m1(outptr + packn, _sum1, vl); + __riscv_vse32_v_f32m1(outptr + packn * 2, _sum2, vl); + __riscv_vse32_v_f32m1(outptr + packn * 3, _sum3, vl); } + + outptr += packn * 4; } for (; jj + 1 < max_jj; jj += 2) { @@ -835,9 +1542,6 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons if (k == 0) { - _sum0 = 
__riscv_vfmv_v_f_f32m1(0.f, vl); - _sum1 = __riscv_vfmv_v_f_f32m1(0.f, vl); - if (pC) { if (broadcast_type_C == 0) @@ -863,6 +1567,11 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons pC += 2; } } + else + { + _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum1 = __riscv_vfmv_v_f_f32m1(0.f, vl); + } } else { @@ -912,8 +1621,6 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons if (k == 0) { - _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl); - if (pC) { if (broadcast_type_C == 0) @@ -935,6 +1642,10 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons pC += 1; } } + else + { + _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl); + } } else { @@ -946,9 +1657,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons for (; kk < max_kk; kk += 1) { vfloat32m1_t _pA = __riscv_vle32_v_f32m1(pA, vl); - vfloat32m1_t _pB = __riscv_vfmv_v_f_f32m1(pB[0], vl); - - _sum0 = __riscv_vfmadd_vv_f32m1(_pA, _pB, _sum0, vl); + _sum0 = __riscv_vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); pA += packn; pB += 1; @@ -998,78 +1707,230 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons int jj = 0; #if __riscv_vector - for (; jj + (packn - 1) < max_jj; jj += packn) + for (; jj + 15 < max_jj; jj += 16) + { + vfloat32m4_t _sum0; + vfloat32m4_t _sum1; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = __riscv_vfmv_v_f_f32m4(pC[0], vl16); + _sum1 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = __riscv_vfmv_v_f_f32m4(pC[0], vl16); + _sum1 = __riscv_vfmv_v_f_f32m4(pC[1], vl16); + } + if (broadcast_type_C == 3) + { + vfloat32m4x2_t _s0 = __riscv_vlseg2e32_v_f32m4x2(pC, vl16); + _sum0 = __riscv_vget_v_f32m4x2_f32m4(_s0, 0); + _sum1 = __riscv_vget_v_f32m4x2_f32m4(_s0, 1); + pC += 16 * 2; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vle32_v_f32m4(pC, vl16); + _sum1 = _sum0; + pC += 16; + } + } + else + { + 
_sum0 = __riscv_vfmv_v_f_f32m4(0.f, vl16); + _sum1 = __riscv_vfmv_v_f_f32m4(0.f, vl16); + } + } + else + { + _sum0 = __riscv_vle32_v_f32m4(outptr, vl16); + _sum1 = __riscv_vle32_v_f32m4(outptr + 16, vl16); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m4_t _pB = __riscv_vle32_v_f32m4(pB, vl16); + + _sum0 = __riscv_vfmadd_vf_f32m4(_pB, pA[0], _sum0, vl16); + _sum1 = __riscv_vfmadd_vf_f32m4(_pB, pA[1], _sum1, vl16); + + pA += 2; + pB += 16; + } + + if (k_end) + { + // if (out_elempack == 1) + { + __riscv_vse32_v_f32m4(outptr0, _sum0, vl16); + __riscv_vse32_v_f32m4(outptr0 + out_hstep, _sum1, vl16); + outptr0 += 16; + } + } + else + { + __riscv_vse32_v_f32m4(outptr, _sum0, vl16); + __riscv_vse32_v_f32m4(outptr + 16, _sum1, vl16); + } + + outptr += 16 * 2; + } + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m2_t _sum0; + vfloat32m2_t _sum1; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = __riscv_vfmv_v_f_f32m2(pC[0], vl8); + _sum1 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = __riscv_vfmv_v_f_f32m2(pC[0], vl8); + _sum1 = __riscv_vfmv_v_f_f32m2(pC[1], vl8); + } + if (broadcast_type_C == 3) + { + vfloat32m2x2_t _s0 = __riscv_vlseg2e32_v_f32m2x2(pC, vl8); + _sum0 = __riscv_vget_v_f32m2x2_f32m2(_s0, 0); + _sum1 = __riscv_vget_v_f32m2x2_f32m2(_s0, 1); + pC += 8 * 2; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vle32_v_f32m2(pC, vl8); + _sum1 = _sum0; + pC += 8; + } + } + else + { + _sum0 = __riscv_vfmv_v_f_f32m2(0.f, vl8); + _sum1 = __riscv_vfmv_v_f_f32m2(0.f, vl8); + } + } + else + { + _sum0 = __riscv_vle32_v_f32m2(outptr, vl8); + _sum1 = __riscv_vle32_v_f32m2(outptr + 8, vl8); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m2_t _pB = __riscv_vle32_v_f32m2(pB, vl8); + + _sum0 = __riscv_vfmadd_vf_f32m2(_pB, pA[0], _sum0, vl8); + _sum1 = __riscv_vfmadd_vf_f32m2(_pB, pA[1], _sum1, vl8); + + pA += 2; + pB += 8; + } + 
+ if (k_end) + { + // if (out_elempack == 1) + { + __riscv_vse32_v_f32m2(outptr0, _sum0, vl8); + __riscv_vse32_v_f32m2(outptr0 + out_hstep, _sum1, vl8); + outptr0 += 8; + } + } + else + { + __riscv_vse32_v_f32m2(outptr, _sum0, vl8); + __riscv_vse32_v_f32m2(outptr + 8, _sum1, vl8); + } + + outptr += 8 * 2; + } + for (; jj + 3 < max_jj; jj += 4) { vfloat32m1_t _sum0; vfloat32m1_t _sum1; if (k == 0) { - _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum1 = __riscv_vfmv_v_f_f32m1(0.f, vl); - if (pC) { if (broadcast_type_C == 0) { - _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum1 = __riscv_vfmv_v_f_f32m1(pC[0], vl); + _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl4); + _sum1 = __riscv_vfmv_v_f_f32m1(pC[0], vl4); } if (broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum1 = __riscv_vfmv_v_f_f32m1(pC[1], vl); + _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl4); + _sum1 = __riscv_vfmv_v_f_f32m1(pC[1], vl4); } if (broadcast_type_C == 3) { - vfloat32m1x2_t _s0 = __riscv_vlseg2e32_v_f32m1x2(pC, vl); + vfloat32m1x2_t _s0 = __riscv_vlseg2e32_v_f32m1x2(pC, vl4); _sum0 = __riscv_vget_v_f32m1x2_f32m1(_s0, 0); _sum1 = __riscv_vget_v_f32m1x2_f32m1(_s0, 1); - pC += packn * 2; + pC += 4 * 2; } if (broadcast_type_C == 4) { - _sum0 = __riscv_vle32_v_f32m1(pC, vl); + _sum0 = __riscv_vle32_v_f32m1(pC, vl4); _sum1 = _sum0; - pC += packn; + pC += 4; } } + else + { + _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl4); + _sum1 = __riscv_vfmv_v_f_f32m1(0.f, vl4); + } } else { - vfloat32m1x2_t _s0 = __riscv_vlseg2e32_v_f32m1x2(outptr, vl); - _sum0 = __riscv_vget_v_f32m1x2_f32m1(_s0, 0); - _sum1 = __riscv_vget_v_f32m1x2_f32m1(_s0, 1); + _sum0 = __riscv_vle32_v_f32m1(outptr, vl4); + _sum1 = __riscv_vle32_v_f32m1(outptr + 4, vl4); } const float* pA = pAT; int kk = 0; for (; kk < max_kk; kk += 1) { - vfloat32m1_t _pB = __riscv_vle32_v_f32m1(pB, vl); + vfloat32m1_t _pB = __riscv_vle32_v_f32m1(pB, vl4); - _sum0 = __riscv_vfmadd_vf_f32m1(_pB, pA[0], _sum0, vl); - _sum1 = 
__riscv_vfmadd_vf_f32m1(_pB, pA[1], _sum1, vl); + _sum0 = __riscv_vfmadd_vf_f32m1(_pB, pA[0], _sum0, vl4); + _sum1 = __riscv_vfmadd_vf_f32m1(_pB, pA[1], _sum1, vl4); pA += 2; - pB += packn; + pB += 4; } if (k_end) { // if (out_elempack == 1) { - __riscv_vse32_v_f32m1(outptr0, _sum0, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep, _sum1, vl); - outptr0 += packn; + __riscv_vse32_v_f32m1(outptr0, _sum0, vl4); + __riscv_vse32_v_f32m1(outptr0 + out_hstep, _sum1, vl4); + outptr0 += 4; } } else { - __riscv_vsseg2e32_v_f32m1x2(outptr, __riscv_vcreate_v_f32m1x2(_sum0, _sum1), vl); + __riscv_vse32_v_f32m1(outptr, _sum0, vl4); + __riscv_vse32_v_f32m1(outptr + 4, _sum1, vl4); } - outptr += packn * 2; + outptr += 4 * 2; } #endif // __riscv_vector for (; jj + 1 < max_jj; jj += 2) @@ -1123,8 +1984,8 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons else { sum00 = outptr[0]; - sum01 = outptr[1]; - sum10 = outptr[2]; + sum10 = outptr[1]; + sum01 = outptr[2]; sum11 = outptr[3]; } @@ -1154,8 +2015,8 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons else { outptr[0] = sum00; - outptr[1] = sum01; - outptr[2] = sum10; + outptr[1] = sum10; + outptr[2] = sum01; outptr[3] = sum11; } @@ -1185,127 +2046,238 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } if (broadcast_type_C == 3) { - sum0 = pC[0]; - sum1 = pC[1]; - pC += 2; + sum0 = pC[0]; + sum1 = pC[1]; + pC += 2; + } + if (broadcast_type_C == 4) + { + sum0 = pC[0]; + sum1 = pC[0]; + pC += 1; + } + } + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[1] * pB[0]; + pA += 2; + pB += 1; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum0; + outptr0[out_hstep] = sum1; + outptr0++; + } + } + else + { + outptr[0] = sum0; + outptr[1] = sum1; + } + + outptr += 2; + } + + pAT += max_kk * 2; + } + 
for (; ii < max_ii; ii += 1) + { + float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j; + + const float* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)CT_tile + i + ii; + } + if (broadcast_type_C == 4) + { + pC = (const float*)CT_tile + j; + } + } + + int jj = 0; +#if __riscv_vector + for (; jj + 15 < max_jj; jj += 16) + { + vfloat32m4_t _sum; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum = __riscv_vfmv_v_f_f32m4(pC[0], vl16); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _sum = __riscv_vle32_v_f32m4(pC, vl16); + pC += 16; + } + } + else + { + _sum = __riscv_vfmv_v_f_f32m4(0.f, vl16); + } + } + else + { + _sum = __riscv_vle32_v_f32m4(outptr, vl16); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m4_t _pB = __riscv_vle32_v_f32m4(pB, vl16); + + _sum = __riscv_vfmadd_vf_f32m4(_pB, pA[0], _sum, vl16); + + pA += 1; + pB += 16; + } + + if (k_end) + { + // if (out_elempack == 1) + { + __riscv_vse32_v_f32m4(outptr0, _sum, vl16); + outptr0 += 16; + } + } + else + { + __riscv_vse32_v_f32m4(outptr, _sum, vl16); + } + + outptr += 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m2_t _sum; + + if (k == 0) + { + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum = __riscv_vfmv_v_f_f32m2(pC[0], vl8); } - if (broadcast_type_C == 4) + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - sum0 = pC[0]; - sum1 = pC[0]; - pC += 1; + _sum = __riscv_vle32_v_f32m2(pC, vl8); + pC += 8; } } + else + { + _sum = __riscv_vfmv_v_f_f32m2(0.f, vl8); + } } else { - sum0 = outptr[0]; - sum1 = outptr[1]; + _sum = __riscv_vle32_v_f32m2(outptr, vl8); } const float* pA = pAT; int kk = 0; for (; kk < max_kk; kk += 1) { - sum0 += pA[0] * pB[0]; - sum1 += pA[1] * pB[0]; - pA += 2; - pB += 1; + vfloat32m2_t _pB = __riscv_vle32_v_f32m2(pB, 
vl8); + + _sum = __riscv_vfmadd_vf_f32m2(_pB, pA[0], _sum, vl8); + + pA += 1; + pB += 8; } if (k_end) { // if (out_elempack == 1) { - outptr0[0] = sum0; - outptr0[out_hstep] = sum1; - outptr0++; + __riscv_vse32_v_f32m2(outptr0, _sum, vl8); + outptr0 += 8; } } else { - outptr[0] = sum0; - outptr[1] = sum1; + __riscv_vse32_v_f32m2(outptr, _sum, vl8); } - outptr += 2; - } - - pAT += max_kk * 2; - } - for (; ii < max_ii; ii += 1) - { - float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j; - - const float* pB = pBT; - - if (pC) - { - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - pC = (const float*)CT_tile + i + ii; - } - if (broadcast_type_C == 4) - { - pC = (const float*)CT_tile + j; - } + outptr += 8; } - - int jj = 0; -#if __riscv_vector - for (; jj + (packn - 1) < max_jj; jj += packn) + for (; jj + 3 < max_jj; jj += 4) { vfloat32m1_t _sum; if (k == 0) { - _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); - if (pC) { if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum = __riscv_vfmv_v_f_f32m1(pC[0], vl); + _sum = __riscv_vfmv_v_f_f32m1(pC[0], vl4); } if (broadcast_type_C == 3 || broadcast_type_C == 4) { - _sum = __riscv_vle32_v_f32m1(pC, vl); - pC += packn; + _sum = __riscv_vle32_v_f32m1(pC, vl4); + pC += 4; } } + else + { + _sum = __riscv_vfmv_v_f_f32m1(0.f, vl4); + } } else { - _sum = __riscv_vle32_v_f32m1(outptr, vl); + _sum = __riscv_vle32_v_f32m1(outptr, vl4); } const float* pA = pAT; int kk = 0; for (; kk < max_kk; kk += 1) { - vfloat32m1_t _pB = __riscv_vle32_v_f32m1(pB, vl); - vfloat32m1_t _pA = __riscv_vfmv_v_f_f32m1(pA[0], vl); + vfloat32m1_t _pB = __riscv_vle32_v_f32m1(pB, vl4); - _sum = __riscv_vfmadd_vv_f32m1(_pA, _pB, _sum, vl); + _sum = __riscv_vfmadd_vf_f32m1(_pB, pA[0], _sum, vl4); pA += 1; - pB += packn; + pB += 4; } if (k_end) { // if (out_elempack == 1) { - __riscv_vse32_v_f32m1(outptr0, _sum, vl); - outptr0 += packn; + __riscv_vse32_v_f32m1(outptr0, _sum, vl4); + outptr0 += 4; } } else { - 
__riscv_vse32_v_f32m1(outptr, _sum, vl); + __riscv_vse32_v_f32m1(outptr, _sum, vl4); } - outptr += packn; + outptr += 4; } #endif // __riscv_vector for (; jj + 1 < max_jj; jj += 2) @@ -1434,12 +2406,14 @@ static void get_optimal_tile_mnk(int M, int N, int K, int constant_TILE_M, int c #if __riscv_vector const int packn = csrr_vlenb() / 4; + const int packn_n = 16; #else const int packn = 4; + const int packn_n = 4; #endif TILE_M = std::max(packn, tile_size / packn * packn); - TILE_N = std::max(packn, tile_size / packn * packn); + TILE_N = std::max(packn_n, tile_size / packn_n * packn_n); TILE_K = std::max(packn, tile_size / packn * packn); if (K > 0) @@ -1451,7 +2425,7 @@ static void get_optimal_tile_mnk(int M, int N, int K, int constant_TILE_M, int c { tile_size = (int)((float)l2_cache_size / 2 / sizeof(float) / TILE_K); TILE_M = std::max(packn, tile_size / packn * packn); - TILE_N = std::max(packn, tile_size / packn * packn); + TILE_N = std::max(packn_n, tile_size / packn_n * packn_n); } } @@ -1466,7 +2440,7 @@ static void get_optimal_tile_mnk(int M, int N, int K, int constant_TILE_M, int c if (N > 0) { int nn_N = (N + TILE_N - 1) / TILE_N; - TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + (packn - 1)) / packn * packn); + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + (packn_n - 1)) / packn_n * packn_n); } if (nT > 1) @@ -1482,7 +2456,7 @@ static void get_optimal_tile_mnk(int M, int N, int K, int constant_TILE_M, int c if (constant_TILE_N > 0) { - TILE_N = (constant_TILE_N + (packn - 1)) / packn * packn; + TILE_N = (constant_TILE_N + (packn_n - 1)) / packn_n * packn_n; } if (constant_TILE_K > 0) @@ -1553,25 +2527,61 @@ static int gemm_riscv(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, i return -100; } - #pragma omp parallel for num_threads(nT) - for (int ppi = 0; ppi < nn_M; ppi++) + if (nT > nn_M) { - const int i = ppi * TILE_M; + Mat AT(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 4u, opt.workspace_allocator); + if 
(AT.empty()) + return -100; - // shadowed variable for less openmp task args - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int nn_MK = nn_M * nn_K; - const int max_ii = std::min((M - i), TILE_M); + // pack A + #pragma omp parallel for num_threads(nT) + for (int ppik = 0; ppik < nn_MK; ppik++) + { + const int ppi = ppik / nn_K; + const int ppk = ppik % nn_K; - Mat topT_tile; - if (K > TILE_K || broadcast_type_C == 3 || output_transpose) - topT_tile = topT.channel(get_omp_thread_num()); + const int i = ppi * TILE_M; + const int k = ppk * TILE_K; + + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + if (transA) + { + transpose_pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); + } + } - for (int j = 0; j < N; j += TILE_N) + const int nn_MN = nn_M * nn_N; + + #pragma omp parallel for num_threads(nT) + for (int ppij = 0; ppij < nn_MN; ppij++) { + const int ppi = ppij / nn_N; + const int ppj = ppij % nn_N; + + const int i = ppi * TILE_M; + const int j = ppj * TILE_N; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? 
A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); const int max_jj = std::min((N - j), TILE_N); + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + if (broadcast_type_C == 3) { pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); @@ -1585,22 +2595,10 @@ static int gemm_riscv(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, i // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); - Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); - if (j == 0) - { - if (transA) - { - transpose_pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); - } - else - { - pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); - } - } - bool k_end = !output_transpose && k + TILE_K >= K; gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end); } @@ -1611,6 +2609,67 @@ static int gemm_riscv(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, i } } } + else + { + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? 
topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (transA) + { + transpose_pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); + } + } + + bool k_end = !output_transpose && k + TILE_K >= K; + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end); + } + + if (output_transpose) + { + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj); + } + } + } + } return 0; } @@ -1669,44 +2728,46 @@ static int gemm_AT_riscv(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blo return -100; } + const int nn_MN = nn_M * nn_N; #pragma omp parallel for num_threads(nT) - for (int ppi = 0; ppi < nn_M; ppi++) + for (int ppij = 0; ppij < nn_MN; ppij++) { + const int ppi = ppij / nn_N; + const int ppj = ppij % nn_N; + const int i = ppi * TILE_M; + const int j = ppj * TILE_N; const int max_ii = std::min((M - i), TILE_M); + const int max_jj = std::min((N - j), TILE_N); Mat topT_tile; if (K > TILE_K || broadcast_type_C == 3 || output_transpose) topT_tile = topT.channel(get_omp_thread_num()); - for (int j = 0; j < N; j += TILE_N) - { - const int max_jj = std::min((N - j), TILE_N); - if (broadcast_type_C == 3) - { - pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); - } + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); + } - const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + const Mat& CT_tile = broadcast_type_C == 3 ? 
topT_tile : C; - for (int k = 0; k < K; k += TILE_K) - { - const int max_kk = std::min((K - k), TILE_K); + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); - // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); - Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); - Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); - bool k_end = !output_transpose && k + TILE_K >= K; - gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end); - } + bool k_end = !output_transpose && k + TILE_K >= K; + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end); + } - if (output_transpose) - { - transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj); - } + if (output_transpose) + { + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj); } } @@ -1725,11 +2786,8 @@ static int gemm_BT_riscv(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blo // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); int nn_M = (M + TILE_M - 1) / TILE_M; - // int nn_N = (N + TILE_N - 1) / TILE_N; - - Mat ATX(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 4u, opt.workspace_allocator); - if (ATX.empty()) - return -100; + int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_K = (K + TILE_K - 1) / TILE_K; Mat topT; if (K > TILE_K || broadcast_type_C == 3 || output_transpose) @@ -1739,25 +2797,61 @@ static int gemm_BT_riscv(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blo return -100; } - #pragma omp parallel for num_threads(nT) - for (int ppi = 0; ppi < nn_M; ppi++) + if (nT > nn_M) { - const int i = ppi * TILE_M; + Mat AT(TILE_K * TILE_M, (K + TILE_K - 1) 
/ TILE_K, (M + TILE_M - 1) / TILE_M, 4u, opt.workspace_allocator); + if (AT.empty()) + return -100; - // shadowed variable for less openmp task args - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int nn_MK = nn_M * nn_K; - const int max_ii = std::min((M - i), TILE_M); + // pack A + #pragma omp parallel for num_threads(nT) + for (int ppik = 0; ppik < nn_MK; ppik++) + { + const int ppi = ppik / nn_K; + const int ppk = ppik % nn_K; - Mat topT_tile; - if (K > TILE_K || broadcast_type_C == 3 || output_transpose) - topT_tile = topT.channel(get_omp_thread_num()); + const int i = ppi * TILE_M; + const int k = ppk * TILE_K; + + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + if (transA) + { + transpose_pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); + } + } - for (int j = 0; j < N; j += TILE_N) + const int nn_MN = nn_M * nn_N; + + #pragma omp parallel for num_threads(nT) + for (int ppij = 0; ppij < nn_MN; ppij++) { + const int ppi = ppij / nn_N; + const int ppj = ppij % nn_N; + + const int i = ppi * TILE_M; + const int j = ppj * TILE_N; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? 
A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); const int max_jj = std::min((N - j), TILE_N); + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + if (broadcast_type_C == 3) { pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); @@ -1771,22 +2865,10 @@ static int gemm_BT_riscv(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blo // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); - Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); - if (j == 0) - { - if (transA) - { - transpose_pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); - } - else - { - pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); - } - } - bool k_end = !output_transpose && k + TILE_K >= K; gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end); @@ -1798,6 +2880,72 @@ static int gemm_BT_riscv(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blo } } } + else + { + Mat ATX(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 4u, opt.workspace_allocator); + if (ATX.empty()) + return -100; + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? 
A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (transA) + { + transpose_pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); + } + } + + bool k_end = !output_transpose && k + TILE_K >= K; + + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end); + } + + if (output_transpose) + { + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj); + } + } + } + } return 0; } @@ -1812,7 +2960,7 @@ static int gemm_AT_BT_riscv(const Mat& AT, const Mat& BT, const Mat& C, Mat& top // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); int nn_M = (M + TILE_M - 1) / TILE_M; - // int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_N = (N + TILE_N - 1) / TILE_N; Mat topT; if (K > TILE_K || broadcast_type_C == 3 || output_transpose) @@ -1822,47 +2970,48 @@ static int gemm_AT_BT_riscv(const Mat& AT, const Mat& BT, const Mat& C, Mat& top return -100; } + const int nn_MN = nn_M * nn_N; #pragma omp parallel for num_threads(nT) - for (int ppi = 0; ppi < nn_M; ppi++) + for (int ppij = 0; ppij < nn_MN; ppij++) { + const int ppi = ppij / nn_N; + const int ppj = ppij % nn_N; + const int i = ppi * TILE_M; + const 
int j = ppj * TILE_N; const int max_ii = std::min((M - i), TILE_M); + const int max_jj = std::min((N - j), TILE_N); Mat topT_tile; if (K > TILE_K || broadcast_type_C == 3 || output_transpose) topT_tile = topT.channel(get_omp_thread_num()); - for (int j = 0; j < N; j += TILE_N) + if (broadcast_type_C == 3) { - const int max_jj = std::min((N - j), TILE_N); - - if (broadcast_type_C == 3) - { - pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); - } + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); + } - const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; - for (int k = 0; k < K; k += TILE_K) - { - const int max_kk = std::min((K - k), TILE_K); + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); - // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); - Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); - Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); - bool k_end = !output_transpose && k + TILE_K >= K; + bool k_end = !output_transpose && k + TILE_K >= K; - gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end); - } + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end); + } - if (output_transpose) - { - transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj); - } + if (output_transpose) + { + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj); } } @@ -1883,6 +3032,9 @@ int Gemm_riscv::create_pipeline(const Option& opt) #if NCNN_ZFH if (support_fp16_storage && opt.use_fp16_storage) { + if (opt.use_fp16_arithmetic) + return 
create_pipeline_fp16sa(opt); + return create_pipeline_fp16s(opt); } #endif @@ -2028,6 +3180,9 @@ int Gemm_riscv::forward(const std::vector& bottom_blobs, std::vector& int elembits = bottom_blob.elembits(); if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) { + if (opt.use_fp16_arithmetic) + return forward_fp16sa(bottom_blobs, top_blobs, opt); + return forward_fp16s(bottom_blobs, top_blobs, opt); } #endif diff --git a/src/layer/riscv/gemm_riscv.h b/src/layer/riscv/gemm_riscv.h index 2ef61f268927..68dd1aa51bcc 100644 --- a/src/layer/riscv/gemm_riscv.h +++ b/src/layer/riscv/gemm_riscv.h @@ -19,6 +19,8 @@ class Gemm_riscv : public Gemm protected: #if NCNN_ZFH + int create_pipeline_fp16sa(const Option& opt); + int forward_fp16sa(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; int create_pipeline_fp16s(const Option& opt); int forward_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; #endif diff --git a/src/layer/riscv/gemm_riscv_zfh.cpp b/src/layer/riscv/gemm_riscv_zfh.cpp index becc3d243a09..2b732b63273b 100644 --- a/src/layer/riscv/gemm_riscv_zfh.cpp +++ b/src/layer/riscv/gemm_riscv_zfh.cpp @@ -16,6 +16,7 @@ namespace ncnn { #if NCNN_ZFH #include "gemm_bf16s_fp16s.h" #include "gemm_fp16s.h" +#include "gemm_fp16sa.h" static int gemm_riscv_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { @@ -76,27 +77,64 @@ static int gemm_riscv_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_b return -100; } - #pragma omp parallel for num_threads(nT) - for (int ppi = 0; ppi < nn_M; ppi++) + if (nT > nn_M) { - const int i = ppi * TILE_M; + Mat AT(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 2u, opt.workspace_allocator); + if (AT.empty()) + return -100; - // shadowed variable for 
less openmp task args - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int nn_MK = nn_M * nn_K; - const int max_ii = std::min((M - i), TILE_M); + // pack A + #pragma omp parallel for num_threads(nT) + for (int ppik = 0; ppik < nn_MK; ppik++) + { + const int ppi = ppik / nn_K; + const int ppk = ppik % nn_K; - Mat topT_tile; - if (K > TILE_K || broadcast_type_C == 3 || output_transpose) - topT_tile = topT.channel(get_omp_thread_num()); + const int i = ppi * TILE_M; + const int k = ppk * TILE_K; + + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + if (transA) + { + transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + } - for (int j = 0; j < N; j += TILE_N) + const int nn_MN = nn_M * nn_N; + + #pragma omp parallel for num_threads(nT) + for (int ppij = 0; ppij < nn_MN; ppij++) { + const int ppi = ppij / nn_N; + const int ppj = ppij % nn_N; + + const int i = ppi * TILE_M; + const int j = ppj * TILE_N; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); const int max_jj = std::min((N - j), TILE_N); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + if (broadcast_type_C == 3) { - pack_A_tile_fp16s(C, topT_tile, i, max_ii, j, max_jj); + pack_A_tile_fp32(C, topT_tile, i, max_ii, j, max_jj); } const Mat& CT_tile = broadcast_type_C == 3 ? 
topT_tile : C; @@ -106,22 +144,10 @@ static int gemm_riscv_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_b // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); - Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); - if (j == 0) - { - if (transA) - { - transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); - } - else - { - pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); - } - } - bool k_end = !output_transpose && k + TILE_K >= K; float _alpha = k + TILE_K >= K ? alpha : 1.f; @@ -134,6 +160,67 @@ static int gemm_riscv_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_b } } } + else + { + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + if (broadcast_type_C == 3) + { + pack_A_tile_fp32(C, topT_tile, i, max_ii, j, max_jj); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? 
topT_tile : C; + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (transA) + { + transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + } + + bool k_end = !output_transpose && k + TILE_K >= K; + float _alpha = k + TILE_K >= K ? alpha : 1.f; + + gemm_transB_packed_tile_fp16s(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); + } + + if (output_transpose) + { + transpose_unpack_output_tile_fp32_to_fp16(topT_tile, top_blob, i, max_ii, j, max_jj); + } + } + } + } return 0; } @@ -192,48 +279,49 @@ static int gemm_AT_riscv_fp16s(const Mat& AT, const Mat& B, const Mat& C, Mat& t return -100; } + const int nn_MN = nn_M * nn_N; #pragma omp parallel for num_threads(nT) - for (int ppi = 0; ppi < nn_M; ppi++) + for (int ppij = 0; ppij < nn_MN; ppij++) { + const int ppi = ppij / nn_N; + const int ppj = ppij % nn_N; + const int i = ppi * TILE_M; + const int j = ppj * TILE_N; const int max_ii = std::min((M - i), TILE_M); + const int max_jj = std::min((N - j), TILE_N); Mat topT_tile; if (K > TILE_K || broadcast_type_C == 3 || output_transpose) topT_tile = topT.channel(get_omp_thread_num()); - for (int j = 0; j < N; j += TILE_N) + if (broadcast_type_C == 3) { - const int max_jj = std::min((N - j), TILE_N); - - if (broadcast_type_C == 3) - { - pack_A_tile_fp16s(C, topT_tile, i, max_ii, j, max_jj); - } + pack_A_tile_fp32(C, topT_tile, i, max_ii, j, max_jj); + } - const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + const Mat& CT_tile = broadcast_type_C == 3 ? 
topT_tile : C; - for (int k = 0; k < K; k += TILE_K) - { - const int max_kk = std::min((K - k), TILE_K); + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); - // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); - Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); - Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); - bool k_end = !output_transpose && k + TILE_K >= K; - float _alpha = k + TILE_K >= K ? alpha : 1.f; + bool k_end = !output_transpose && k + TILE_K >= K; + float _alpha = k + TILE_K >= K ? alpha : 1.f; - gemm_transB_packed_tile_fp16s(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); - } + gemm_transB_packed_tile_fp16s(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); + } - if (output_transpose) - { - transpose_unpack_output_tile_fp32_to_fp16(topT_tile, top_blob, i, max_ii, j, max_jj); - } + if (output_transpose) + { + transpose_unpack_output_tile_fp32_to_fp16(topT_tile, top_blob, i, max_ii, j, max_jj); } } @@ -252,11 +340,8 @@ static int gemm_BT_riscv_fp16s(const Mat& A, const Mat& BT, const Mat& C, Mat& t // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); int nn_M = (M + TILE_M - 1) / TILE_M; - // int nn_N = (N + TILE_N - 1) / TILE_N; - - Mat ATX(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 2u, opt.workspace_allocator); - if (ATX.empty()) - return -100; + int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_K = (K + TILE_K - 1) / TILE_K; Mat topT; if (K > TILE_K || broadcast_type_C == 3 || output_transpose) @@ -266,28 +351,64 @@ static int gemm_BT_riscv_fp16s(const Mat& A, const Mat& BT, const Mat& C, Mat& t return -100; } - #pragma omp parallel for 
num_threads(nT) - for (int ppi = 0; ppi < nn_M; ppi++) + if (nT > nn_M) { - const int i = ppi * TILE_M; + Mat AT(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 2u, opt.workspace_allocator); + if (AT.empty()) + return -100; - // shadowed variable for less openmp task args - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int nn_MK = nn_M * nn_K; - const int max_ii = std::min((M - i), TILE_M); + // pack A + #pragma omp parallel for num_threads(nT) + for (int ppik = 0; ppik < nn_MK; ppik++) + { + const int ppi = ppik / nn_K; + const int ppk = ppik % nn_K; - Mat topT_tile; - if (K > TILE_K || broadcast_type_C == 3 || output_transpose) - topT_tile = topT.channel(get_omp_thread_num()); + const int i = ppi * TILE_M; + const int k = ppk * TILE_K; + + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + if (transA) + { + transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + } - for (int j = 0; j < N; j += TILE_N) + const int nn_MN = nn_M * nn_N; + + #pragma omp parallel for num_threads(nT) + for (int ppij = 0; ppij < nn_MN; ppij++) { + const int ppi = ppij / nn_N; + const int ppj = ppij % nn_N; + + const int i = ppi * TILE_M; + const int j = ppj * TILE_N; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? 
A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); const int max_jj = std::min((N - j), TILE_N); + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + if (broadcast_type_C == 3) { - pack_A_tile_fp16s(C, topT_tile, i, max_ii, j, max_jj); + pack_A_tile_fp32(C, topT_tile, i, max_ii, j, max_jj); } const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; @@ -298,22 +419,10 @@ static int gemm_BT_riscv_fp16s(const Mat& A, const Mat& BT, const Mat& C, Mat& t // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); - Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); - if (j == 0) - { - if (transA) - { - transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); - } - else - { - pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); - } - } - bool k_end = !output_transpose && k + TILE_K >= K; float _alpha = k + TILE_K >= K ? alpha : 1.f; @@ -326,6 +435,73 @@ static int gemm_BT_riscv_fp16s(const Mat& A, const Mat& BT, const Mat& C, Mat& t } } } + else + { + Mat ATX(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 2u, opt.workspace_allocator); + if (ATX.empty()) + return -100; + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? 
A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile_fp32(C, topT_tile, i, max_ii, j, max_jj); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (transA) + { + transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + } + + bool k_end = !output_transpose && k + TILE_K >= K; + float _alpha = k + TILE_K >= K ? 
alpha : 1.f; + + gemm_transB_packed_tile_fp16s(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); + } + + if (output_transpose) + { + transpose_unpack_output_tile_fp32_to_fp16(topT_tile, top_blob, i, max_ii, j, max_jj); + } + } + } + } return 0; } @@ -340,7 +516,7 @@ static int gemm_AT_BT_riscv_fp16s(const Mat& AT, const Mat& BT, const Mat& C, Ma // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); int nn_M = (M + TILE_M - 1) / TILE_M; - // int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_N = (N + TILE_N - 1) / TILE_N; Mat topT; if (K > TILE_K || broadcast_type_C == 3 || output_transpose) @@ -350,28 +526,175 @@ static int gemm_AT_BT_riscv_fp16s(const Mat& AT, const Mat& BT, const Mat& C, Ma return -100; } + const int nn_MN = nn_M * nn_N; #pragma omp parallel for num_threads(nT) - for (int ppi = 0; ppi < nn_M; ppi++) + for (int ppij = 0; ppij < nn_MN; ppij++) { + const int ppi = ppij / nn_N; + const int ppj = ppij % nn_N; + const int i = ppi * TILE_M; + const int j = ppj * TILE_N; const int max_ii = std::min((M - i), TILE_M); + const int max_jj = std::min((N - j), TILE_N); Mat topT_tile; if (K > TILE_K || broadcast_type_C == 3 || output_transpose) topT_tile = topT.channel(get_omp_thread_num()); - for (int j = 0; j < N; j += TILE_N) + if (broadcast_type_C == 3) + { + pack_A_tile_fp32(C, topT_tile, i, max_ii, j, max_jj); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + bool k_end = !output_transpose && k + TILE_K >= K; + float _alpha = k + TILE_K >= K ? 
alpha : 1.f; + + gemm_transB_packed_tile_fp16s(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); + } + + if (output_transpose) + { + transpose_unpack_output_tile_fp32_to_fp16(topT_tile, top_blob, i, max_ii, j, max_jj); + } + } + + return 0; +} + +static int gemm_riscv_fp16sa(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_fp16sa(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_K = (K + TILE_K - 1) / TILE_K; + + Mat ATX(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 2u, opt.workspace_allocator); + if (ATX.empty()) + return -100; + Mat BT(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 2u, opt.workspace_allocator); + if (BT.empty()) + return -100; + + const int nn_NK = nn_N * nn_K; + + // pack B + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile_fp16sa(B, BT_tile, j, max_jj, k, max_kk); + } + else + { + transpose_pack_B_tile_fp16sa(B, 
BT_tile, j, max_jj, k, max_kk); + } + } + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + { + topT.create(TILE_N * TILE_M, 1, nT, 2u, opt.workspace_allocator); + if (topT.empty()) + return -100; + } + + if (nT > nn_M) + { + Mat AT(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 2u, opt.workspace_allocator); + if (AT.empty()) + return -100; + + const int nn_MK = nn_M * nn_K; + + // pack A + #pragma omp parallel for num_threads(nT) + for (int ppik = 0; ppik < nn_MK; ppik++) + { + const int ppi = ppik / nn_K; + const int ppk = ppik % nn_K; + + const int i = ppi * TILE_M; + const int k = ppk * TILE_K; + + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + if (transA) + { + transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + } + + const int nn_MN = nn_M * nn_N; + + #pragma omp parallel for num_threads(nT) + for (int ppij = 0; ppij < nn_MN; ppij++) { + const int ppi = ppij / nn_N; + const int ppj = ppij % nn_N; + + const int i = ppi * TILE_M; + const int j = ppj * TILE_N; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); const int max_jj = std::min((N - j), TILE_N); + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + if (broadcast_type_C == 3) { - pack_A_tile_fp16s(C, topT_tile, i, max_ii, j, max_jj); + pack_A_tile_bf16_fp16(C, topT_tile, i, max_ii, j, max_jj); } const Mat& CT_tile = broadcast_type_C == 3 ? 
topT_tile : C; - for (int k = 0; k < K; k += TILE_K) { const int max_kk = std::min((K - k), TILE_K); @@ -385,19 +708,739 @@ static int gemm_AT_BT_riscv_fp16s(const Mat& AT, const Mat& BT, const Mat& C, Ma bool k_end = !output_transpose && k + TILE_K >= K; float _alpha = k + TILE_K >= K ? alpha : 1.f; - gemm_transB_packed_tile_fp16s(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); + gemm_transB_packed_tile_fp16sa(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); } if (output_transpose) { - transpose_unpack_output_tile_fp32_to_fp16(topT_tile, top_blob, i, max_ii, j, max_jj); + transpose_unpack_output_tile_fp16sa(topT_tile, top_blob, i, max_ii, j, max_jj); } } } - + else + { + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + if (broadcast_type_C == 3) + { + pack_A_tile_bf16_fp16(C, topT_tile, i, max_ii, j, max_jj); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? 
topT_tile : C; + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (transA) + { + transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + } + + bool k_end = !output_transpose && k + TILE_K >= K; + float _alpha = k + TILE_K >= K ? alpha : 1.f; + + gemm_transB_packed_tile_fp16sa(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); + } + + if (output_transpose) + { + transpose_unpack_output_tile_fp16sa(topT_tile, top_blob, i, max_ii, j, max_jj); + } + } + } + } + + return 0; +} + +static int gemm_AT_riscv_fp16sa(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + const int N = transB ? (B.dims == 3 ? 
B.c : B.h) * B.elempack : B.w; + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_fp16sa(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_K = (K + TILE_K - 1) / TILE_K; + + Mat BT(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 2u, opt.workspace_allocator); + if (BT.empty()) + return -100; + + const int nn_NK = nn_N * nn_K; + + // pack B + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile_fp16sa(B, BT_tile, j, max_jj, k, max_kk); + } + else + { + transpose_pack_B_tile_fp16sa(B, BT_tile, j, max_jj, k, max_kk); + } + } + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + { + topT.create(TILE_N * TILE_M, 1, nT, 2u, opt.workspace_allocator); + if (topT.empty()) + return -100; + } + + const int nn_MN = nn_M * nn_N; + #pragma omp parallel for num_threads(nT) + for (int ppij = 0; ppij < nn_MN; ppij++) + { + const int ppi = ppij / nn_N; + const int ppj = ppij % nn_N; + + const int i = ppi * TILE_M; + const int j = ppj * TILE_N; + + const int max_ii = std::min((M - i), TILE_M); + const int max_jj = std::min((N - j), TILE_N); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + if (broadcast_type_C == 3) + { + pack_A_tile_bf16_fp16(C, topT_tile, i, max_ii, j, max_jj); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? 
topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + bool k_end = !output_transpose && k + TILE_K >= K; + float _alpha = k + TILE_K >= K ? alpha : 1.f; + + gemm_transB_packed_tile_fp16sa(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); + } + + if (output_transpose) + { + transpose_unpack_output_tile_fp16sa(topT_tile, top_blob, i, max_ii, j, max_jj); + } + } + + return 0; +} + +static int gemm_BT_riscv_fp16sa(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + const int M = transA ? A.w : (A.dims == 3 ? 
A.c : A.h) * A.elempack; + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_fp16sa(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_K = (K + TILE_K - 1) / TILE_K; + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + { + topT.create(TILE_N * TILE_M, 1, nT, 2u, opt.workspace_allocator); + if (topT.empty()) + return -100; + } + + if (nT > nn_M) + { + Mat AT(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 2u, opt.workspace_allocator); + if (AT.empty()) + return -100; + + const int nn_MK = nn_M * nn_K; + + // pack A + #pragma omp parallel for num_threads(nT) + for (int ppik = 0; ppik < nn_MK; ppik++) + { + const int ppi = ppik / nn_K; + const int ppk = ppik % nn_K; + + const int i = ppi * TILE_M; + const int k = ppk * TILE_K; + + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + if (transA) + { + transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + } + + const int nn_MN = nn_M * nn_N; + + #pragma omp parallel for num_threads(nT) + for (int ppij = 0; ppij < nn_MN; ppij++) + { + const int ppi = ppij / nn_N; + const int ppj = ppij % nn_N; + + const int i = ppi * TILE_M; + const int j = ppj * TILE_N; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? 
A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + const int max_jj = std::min((N - j), TILE_N); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + if (broadcast_type_C == 3) + { + pack_A_tile_bf16_fp16(C, topT_tile, i, max_ii, j, max_jj); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + bool k_end = !output_transpose && k + TILE_K >= K; + float _alpha = k + TILE_K >= K ? alpha : 1.f; + + gemm_transB_packed_tile_fp16sa(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); + } + + if (output_transpose) + { + transpose_unpack_output_tile_fp16sa(topT_tile, top_blob, i, max_ii, j, max_jj); + } + } + } + else + { + Mat ATX(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 2u, opt.workspace_allocator); + if (ATX.empty()) + return -100; + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile_bf16_fp16(C, topT_tile, i, max_ii, j, max_jj); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? 
topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (transA) + { + transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + } + + bool k_end = !output_transpose && k + TILE_K >= K; + float _alpha = k + TILE_K >= K ? alpha : 1.f; + + gemm_transB_packed_tile_fp16sa(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); + } + + if (output_transpose) + { + transpose_unpack_output_tile_fp16sa(topT_tile, top_blob, i, max_ii, j, max_jj); + } + } + } + } + + return 0; +} + +static int gemm_AT_BT_riscv_fp16sa(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_fp16sa(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + int nn_N = (N + TILE_N - 1) / TILE_N; + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + { + topT.create(TILE_N * TILE_M, 1, nT, 2u, opt.workspace_allocator); + if (topT.empty()) + return -100; + } + + const int nn_MN = nn_M * nn_N; + #pragma omp parallel for num_threads(nT) + for (int ppij = 0; ppij < nn_MN; ppij++) + { + const int ppi = ppij / nn_N; + const int ppj = ppij % nn_N; + + const int i = ppi * TILE_M; + const int j = ppj * TILE_N; + + const int max_ii = std::min((M - 
i), TILE_M); + const int max_jj = std::min((N - j), TILE_N); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + if (broadcast_type_C == 3) + { + pack_A_tile_bf16_fp16(C, topT_tile, i, max_ii, j, max_jj); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + bool k_end = !output_transpose && k + TILE_K >= K; + float _alpha = k + TILE_K >= K ? alpha : 1.f; + + gemm_transB_packed_tile_fp16sa(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); + } + + if (output_transpose) + { + transpose_unpack_output_tile_fp16sa(topT_tile, top_blob, i, max_ii, j, max_jj); + } + } + + return 0; +} + +int Gemm_riscv::create_pipeline_fp16sa(const Option& opt) +{ + if (constantA) + { + const int M = constantM; + const int K = constantK; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_fp16sa(M, 0, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + + AT_data.create(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 2u, (Allocator*)0); + if (AT_data.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = AT_data.channel(i / TILE_M).row_range(k / TILE_K, 1); + + if (transA) + { + transpose_pack_A_tile_fp32_to_fp16(A_data, AT_tile, i, max_ii, k, max_kk); + } + else + { 
+ pack_A_tile_fp32_to_fp16(A_data, AT_tile, i, max_ii, k, max_kk); + } + } + } + + if (opt.lightmode) + A_data.release(); + } + + if (constantB) + { + const int N = constantN; + const int K = constantK; + + Mat B_data_fp16; + cast_float32_to_float16(B_data, B_data_fp16, opt); + if (B_data_fp16.empty()) + return -100; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_fp16sa(0, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_N = (N + TILE_N - 1) / TILE_N; + + BT_data.create(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 2u, (Allocator*)0); + if (BT_data.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_N; ppj++) + { + const int j = ppj * TILE_N; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT_data.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile_fp16sa(B_data_fp16, BT_tile, j, max_jj, k, max_kk); + } + else + { + transpose_pack_B_tile_fp16sa(B_data_fp16, BT_tile, j, max_jj, k, max_kk); + } + } + } + + if (opt.lightmode) + B_data.release(); + } + + if (constantC && constant_broadcast_type_C != -1) + { + cast_float32_to_float16(C_data, CT_data, opt); + if (CT_data.empty()) + return -100; + +#if __riscv_vector + const int packn = csrr_vlenb() / 2; + + if (constant_broadcast_type_C == 3 && opt.use_packing_layout) + { + int C_elempack = constantM % packn == 0 ? 
packn : 1; + Mat tmp; + convert_packing(CT_data, tmp, C_elempack, opt); + CT_data = tmp; + if (CT_data.empty()) + return -100; + } +#endif // __riscv_vector + + // pre-multiply C with beta + if (beta != 1.f) + { + const int size = CT_data.total() * CT_data.elempack; + __fp16* ptr = CT_data; + for (int i = 0; i < size; i++) + { + ptr[i] *= (__fp16)beta; + } + } + + if (opt.lightmode) + C_data.release(); + } + + if (constantA || constantB || constantC) + { + nT = opt.num_threads; + } + return 0; } +int Gemm_riscv::forward_fp16sa(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + int M; + int N; + if (constantA && constantB) + { + M = constantM; + N = constantN; + } + else if (constantA) + { + const Mat& B = bottom_blobs[0]; + M = constantM; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + else if (constantB) + { + const Mat& A = bottom_blobs[0]; + M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + N = constantN; + } + else + { + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + + Mat C; + int broadcast_type_C = 0; + if (constantC) + { + C = CT_data; + broadcast_type_C = constant_broadcast_type_C; + } + else + { + if (constantA && constantB) + { + C = bottom_blobs.size() == 1 ? bottom_blobs[0] : Mat(); + } + else if (constantA) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else if (constantB) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else + { + C = bottom_blobs.size() == 3 ? 
bottom_blobs[2] : Mat(); + } + + if (!C.empty()) + { + if (C.dims == 1 && C.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w * C.elempack == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w * C.elempack == N) + { + // N + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h * C.elempack == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == M) + { + // MxN + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == 1) + { + // 1xN + broadcast_type_C = 4; + } + + if (beta != 1.f) + { + Mat CT_data; + CT_data.create_like(C, opt.workspace_allocator); + if (CT_data.empty()) + return -100; + + const int size = C.total() * C.elempack; + const __fp16* ptr = C; + __fp16* outptr = CT_data; + for (int i = 0; i < size; i++) + { + outptr[i] = ptr[i] * (__fp16)beta; + } + + C = CT_data; + } + } + } + +#if __riscv_vector + const int packn = csrr_vlenb() / 2; +#endif + + int out_elempack = 1; +#if __riscv_vector + if (opt.use_packing_layout) + { + int outh = output_transpose ? N : M; + out_elempack = outh % packn == 0 ? packn : 1; + } +#endif // __riscv_vector + if (output_elempack) + out_elempack = output_elempack; + size_t out_elemsize = 2u * out_elempack; + + Mat& top_blob = top_blobs[0]; + if (output_transpose) + { + if (output_N1M) + top_blob.create(M, 1, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(M, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + else + { + if (output_N1M) + top_blob.create(N, 1, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(N, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + if (top_blob.empty()) + return -100; + + int _nT = nT ? 
nT : opt.num_threads; + if (nT != 0 && opt.num_threads != nT) + { + // force num_threads the same as in create_pipeline + // so we could use pre-packed A/B from the same tile config + NCNN_LOGE("opt.num_threads %d changed, gemm will use load-time value %d", opt.num_threads, nT); + } + + int ret = 0; + if (constantA && constantB) + { + ret = gemm_AT_BT_riscv_fp16sa(AT_data, BT_data, C, top_blob, broadcast_type_C, constantM, constantN, constantK, output_transpose, alpha, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + else if (constantA) + { + const Mat& B = bottom_blobs[0]; + ret = gemm_AT_riscv_fp16sa(AT_data, B, C, top_blob, broadcast_type_C, constantM, constantK, transB, output_transpose, alpha, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + else if (constantB) + { + const Mat& A = bottom_blobs[0]; + ret = gemm_BT_riscv_fp16sa(A, BT_data, C, top_blob, broadcast_type_C, constantN, constantK, transA, output_transpose, alpha, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + else + { + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + ret = gemm_riscv_fp16sa(A, B, C, top_blob, broadcast_type_C, transA, transB, output_transpose, alpha, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + if (ret != 0) + return ret; + + return 0; +} int Gemm_riscv::create_pipeline_fp16s(const Option& opt) { if (constantA)