From 917dfd5fb8b78da43ff7fc6e42632c946e65e21f Mon Sep 17 00:00:00 2001 From: chenglimin Date: Tue, 10 Feb 2026 16:00:52 +0800 Subject: [PATCH 01/14] add the deformableconv2d operator for RVV backend, with 12.94x-20.16x speedup over scaler implementation --- src/layer/riscv/deformableconv2d_pack1ton.h | 160 +++++++ src/layer/riscv/deformableconv2d_packn.h | 182 ++++++++ src/layer/riscv/deformableconv2d_packnto1.h | 174 +++++++ src/layer/riscv/deformableconv2d_riscv.cpp | 482 ++++++++++++++++++++ src/layer/riscv/deformableconv2d_riscv.h | 43 ++ 5 files changed, 1041 insertions(+) create mode 100644 src/layer/riscv/deformableconv2d_pack1ton.h create mode 100644 src/layer/riscv/deformableconv2d_packn.h create mode 100644 src/layer/riscv/deformableconv2d_packnto1.h create mode 100644 src/layer/riscv/deformableconv2d_riscv.cpp create mode 100644 src/layer/riscv/deformableconv2d_riscv.h diff --git a/src/layer/riscv/deformableconv2d_pack1ton.h b/src/layer/riscv/deformableconv2d_pack1ton.h new file mode 100644 index 000000000000..dfa046c15e57 --- /dev/null +++ b/src/layer/riscv/deformableconv2d_pack1ton.h @@ -0,0 +1,160 @@ + + +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void deformableconv2d_pack1ton(const std::vector& bottom_blobs, Mat& top_blob, const Mat& weight_data_packed, const Mat& bias_data, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int pad_left, int pad_top, int activation_type, const Mat& activation_params, const Option& opt) +{ + const Mat& bottom_blob = bottom_blobs[0]; + const Mat& offset = bottom_blobs[1]; + const bool has_mask = (bottom_blobs.size() == 3); + const bool offset_not_pack = offset.elempack == 1; + const bool mask_not_pack = has_mask ? bottom_blobs[2].elempack == 1 : true; + + int w = bottom_blob.w; + int h = bottom_blob.h; + int inch = bottom_blob.c; + + int outw = top_blob.w; + int outh = top_blob.h; + int outch = top_blob.c; + + const float* bias_data_ptr = bias_data; + const int packn = csrr_vlenb() / 4; + const size_t vl = __riscv_vsetvl_e32m1(packn); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int h_col = 0; h_col < outh; h_col++) + { + for (int w_col = 0; w_col < outw; w_col++) + { + int h_in = h_col * stride_h - pad_top; + int w_in = w_col * stride_w - pad_left; + for (int oc = 0; oc < outch; oc++) + { + const float* kptr = weight_data_packed.channel(oc); + float* outptr = top_blob.channel(oc); + vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); + if (bias_data_ptr) + _sum = __riscv_vle32_v_f32m1(bias_data_ptr + oc * packn, vl); + + for (int i = 0; i < kernel_h; i++) + { + for (int j = 0; j < kernel_w; j++) + { + float offset_h = 0.f; + float offset_w = 0.f; + float mask_ = 1.f; + if (offset_not_pack) + { + offset_h = offset.channel((i * kernel_w + j) * 2).row(h_col)[w_col]; + offset_w = offset.channel((i * kernel_w + j) * 2 + 1).row(h_col)[w_col]; + } + else + { + const int y_c = (i * kernel_w + j) * 2; + const int x_c = (i * kernel_w + j) * 2 + 1; + offset_h = offset.channel(y_c / offset.elempack).row(h_col)[w_col * offset.elempack + y_c % offset.elempack]; + offset_w = offset.channel(x_c / offset.elempack).row(h_col)[w_col * offset.elempack + x_c % offset.elempack]; + } + if (has_mask) + { + const Mat& mask = bottom_blobs[2]; + if (mask_not_pack) + { + mask_ = mask.channel(i * kernel_w + j).row(h_col)[w_col]; + } + else + { + const int m_c = i * kernel_w + j; + mask_ = mask.channel(m_c / mask.elempack).row(h_col)[w_col * mask.elempack + m_c % mask.elempack]; + } + } + const float h_im = h_in + i * dilation_h + offset_h; + const float w_im = w_in + j * dilation_w + offset_w; + + // Bilinear + const bool cond = h_im > -1 && w_im > -1 && h_im < h && w_im < w; + float w1 = 0.f; + float w2 = 0.f; + float w3 = 0.f; + float w4 = 0.f; + bool v1_cond = false; + bool v2_cond = false; + bool v3_cond = false; + bool v4_cond = false; + int v1_pos = 0; + int v2_pos = 0; + int v3_pos = 0; + int v4_pos = 0; + if (cond) + { + int h_low = (int)floorf(h_im); + int w_low = (int)floorf(w_im); + int h_high = h_low + 1; + int w_high = w_low + 1; + + float lh = h_im - h_low; + float lw = w_im - w_low; + float hh = 1 - lh; + float hw = 1 - lw; + + v1_cond = (h_low >= 0 && w_low >= 0); + v2_cond = (h_low >= 0 && w_high <= w - 1); + v3_cond = (h_high <= h - 1 && w_low >= 0); + v4_cond = (h_high <= h - 1 && w_high <= w - 1); + if (v1_cond) + v1_pos = h_low * w + w_low; + if (v2_cond) + v2_pos = h_low * w + w_high; + if (v3_cond) + v3_pos = h_high * w + w_low; + if (v4_cond) + v4_pos = h_high * w + w_high; + + w1 = hh * hw; + w2 = hh * lw; + w3 = lh * hw; + w4 = lh * lw; + } + + for (int ic = 0; ic < inch; ic++) + { + const float* data_im_ptr = bottom_blob.channel(ic); + + if (cond) + { + float v_in = 0.f; + if (v1_cond) v_in += data_im_ptr[v1_pos] * w1; + if (v2_cond) v_in += data_im_ptr[v2_pos] * w2; + if (v3_cond) v_in += data_im_ptr[v3_pos] * w3; + if (v4_cond) v_in += data_im_ptr[v4_pos] * w4; + + if (has_mask) v_in *= mask_; + + vfloat32m1_t _w = __riscv_vle32_v_f32m1(kptr, vl); + _sum = __riscv_vfmacc_vf_f32m1(_sum, v_in, _w, vl); + } + + kptr += packn; + } + } + } + _sum = activation_ps(_sum, activation_type, activation_params, vl); + __riscv_vse32_v_f32m1(outptr + (h_col * outw + w_col) * packn, _sum, vl); + } + } + } +} + diff --git a/src/layer/riscv/deformableconv2d_packn.h b/src/layer/riscv/deformableconv2d_packn.h new file mode 100644 index 000000000000..c44e8bacc19e --- /dev/null +++ b/src/layer/riscv/deformableconv2d_packn.h @@ -0,0 +1,182 @@ + + +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& top_blob, const Mat& weight_data_packed, const Mat& bias_data, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int pad_left, int pad_top, int activation_type, const Mat& activation_params, const Option& opt) +{ + const Mat& bottom_blob = bottom_blobs[0]; + const Mat& offset = bottom_blobs[1]; + const bool has_mask = (bottom_blobs.size() == 3); + const bool offset_not_pack = offset.elempack == 1; + const bool mask_not_pack = has_mask ? bottom_blobs[2].elempack == 1 : true; + + int w = bottom_blob.w; + int h = bottom_blob.h; + int inch = bottom_blob.c; + + int outw = top_blob.w; + int outh = top_blob.h; + int outch = top_blob.c; + + const float* bias_data_ptr = bias_data; + const int packn = csrr_vlenb() / 4; + const size_t vl = __riscv_vsetvl_e32m1(packn); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int h_col = 0; h_col < outh; h_col++) + { + for (int w_col = 0; w_col < outw; w_col++) + { + int h_in = h_col * stride_h - pad_top; + int w_in = w_col * stride_w - pad_left; + for (int oc = 0; oc < outch; oc++) + { + const float* kptr = weight_data_packed.channel(oc); + float* outptr = top_blob.channel(oc); + vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); + if (bias_data_ptr) + _sum = __riscv_vle32_v_f32m1(bias_data_ptr + oc * packn, vl); + + for (int i = 0; i < kernel_h; i++) + { + for (int j = 0; j < kernel_w; j++) + { + float offset_h = 0.f; + float offset_w = 0.f; + float mask_ = 1.f; + if (offset_not_pack) + { + offset_h = offset.channel((i * kernel_w + j) * 2).row(h_col)[w_col]; + offset_w = offset.channel((i * kernel_w + j) * 2 + 1).row(h_col)[w_col]; + } + else + { + const int y_c = (i * kernel_w + j) * 2; + const int x_c = (i * kernel_w + j) * 2 + 1; + offset_h = offset.channel(y_c / offset.elempack).row(h_col)[w_col * offset.elempack + y_c % offset.elempack]; + offset_w = offset.channel(x_c / offset.elempack).row(h_col)[w_col * offset.elempack + x_c % offset.elempack]; + } + if (has_mask) + { + const Mat& mask = bottom_blobs[2]; + if (mask_not_pack) + { + mask_ = mask.channel(i * kernel_w + j).row(h_col)[w_col]; + } + else + { + const int m_c = i * kernel_w + j; + mask_ = mask.channel(m_c / mask.elempack).row(h_col)[w_col * mask.elempack + m_c % mask.elempack]; + } + } + const float h_im = h_in + i * dilation_h + offset_h; + const float w_im = w_in + j * dilation_w + offset_w; + + // Bilinear + const bool cond = h_im > -1 && w_im > -1 && h_im < h && w_im < w; + float w1 = 0.f; + float w2 = 0.f; + float w3 = 0.f; + float w4 = 0.f; + bool v1_cond = false; + bool v2_cond = false; + bool v3_cond = false; + bool v4_cond = false; + int v1_pos = 0; + int v2_pos = 0; + int v3_pos = 0; + int v4_pos = 0; + if (cond) + { + int h_low = (int)floorf(h_im); + int w_low = (int)floorf(w_im); + int h_high = h_low + 1; + int w_high = w_low + 1; + + float lh = h_im - h_low; + float lw = w_im - w_low; + float hh = 1 - lh; + float hw = 1 - lw; + + v1_cond = (h_low >= 0 && w_low >= 0); + v2_cond = (h_low >= 0 && w_high <= w - 1); + v3_cond = (h_high <= h - 1 && w_low >= 0); + v4_cond = (h_high <= h - 1 && w_high <= w - 1); + if (v1_cond) + v1_pos = h_low * w + w_low; + if (v2_cond) + v2_pos = h_low * w + w_high; + if (v3_cond) + v3_pos = h_high * w + w_low; + if (v4_cond) + v4_pos = h_high * w + w_high; + + w1 = hh * hw; + w2 = hh * lw; + w3 = lh * hw; + w4 = lh * lw; + } + + for (int ic = 0; ic < inch; ic++) + { + const float* data_im_ptr = bottom_blob.channel(ic); + + if (cond) + { + vfloat32m1_t _val = __riscv_vfmv_v_f_f32m1(0.f, vl); + + // Since we are iterating over input channels which are packed, + // we need to handle each element in the pack. + // However, the weight layout for packn is: + // [outch/packn][kh][kw][inch/packn][packn_in][packn_out] + // Wait, let's check the weight transformation in deformableconv2d_riscv.cpp + // weight_data_tm.create(num_input * maxk * num_output / (elempack * out_elempack), (size_t)4u * elempack * out_elempack, elempack * out_elempack); + // It seems the weight is packed as [packn_in * packn_out] + + // For each input channel pack (size packn), we have packn input values. + // Each input value contributes to all packn output values. + // So we have packn * packn weights for this block. + + // Let's look at x86 implementation again. + // _val_channel0..3 corresponds to the 4 input values in the pack. + // _conv_w0..3 corresponds to the weights for these input values. + // Each _conv_w is a vector of size 4 (out_elempack), representing weights for one input channel to all 4 output channels. + + for (int k = 0; k < packn; k++) + { + float v_in = 0.f; + if (v1_cond) v_in += data_im_ptr[v1_pos * packn + k] * w1; + if (v2_cond) v_in += data_im_ptr[v2_pos * packn + k] * w2; + if (v3_cond) v_in += data_im_ptr[v3_pos * packn + k] * w3; + if (v4_cond) v_in += data_im_ptr[v4_pos * packn + k] * w4; + + if (has_mask) v_in *= mask_; + + vfloat32m1_t _w = __riscv_vle32_v_f32m1(kptr + k * packn, vl); + _sum = __riscv_vfmacc_vf_f32m1(_sum, v_in, _w, vl); + } + } + + kptr += packn * packn; + } + } + } + _sum = activation_ps(_sum, activation_type, activation_params, vl); + __riscv_vse32_v_f32m1(outptr + (h_col * outw + w_col) * packn, _sum, vl); + } + } + } +} + diff --git a/src/layer/riscv/deformableconv2d_packnto1.h b/src/layer/riscv/deformableconv2d_packnto1.h new file mode 100644 index 000000000000..9c04b28c35ca --- /dev/null +++ b/src/layer/riscv/deformableconv2d_packnto1.h @@ -0,0 +1,174 @@ + + + +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void deformableconv2d_packnto1(const std::vector& bottom_blobs, Mat& top_blob, const Mat& weight_data_packed, const Mat& bias_data, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int pad_left, int pad_top, int activation_type, const Mat& activation_params, const Option& opt) +{ + const Mat& bottom_blob = bottom_blobs[0]; + const Mat& offset = bottom_blobs[1]; + const bool has_mask = (bottom_blobs.size() == 3); + const bool offset_not_pack = offset.elempack == 1; + const bool mask_not_pack = has_mask ? bottom_blobs[2].elempack == 1 : true; + + int w = bottom_blob.w; + int h = bottom_blob.h; + int inch = bottom_blob.c; + + int outw = top_blob.w; + int outh = top_blob.h; + int outch = top_blob.c; + + const float* bias_data_ptr = bias_data; + const int packn = csrr_vlenb() / 4; + const size_t vl = __riscv_vsetvl_e32m1(packn); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int h_col = 0; h_col < outh; h_col++) + { + for (int w_col = 0; w_col < outw; w_col++) + { + int h_in = h_col * stride_h - pad_top; + int w_in = w_col * stride_w - pad_left; + for (int oc = 0; oc < outch; oc++) + { + const float* kptr = weight_data_packed.channel(oc); + float* outptr = top_blob.channel(oc); + float sum = 0.f; + if (bias_data_ptr) + sum = bias_data_ptr[oc]; + + vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); + + for (int i = 0; i < kernel_h; i++) + { + for (int j = 0; j < kernel_w; j++) + { + float offset_h = 0.f; + float offset_w = 0.f; + float mask_ = 1.f; + if (offset_not_pack) + { + offset_h = offset.channel((i * kernel_w + j) * 2).row(h_col)[w_col]; + offset_w = offset.channel((i * kernel_w + j) * 2 + 1).row(h_col)[w_col]; + } + else + { + const int y_c = (i * kernel_w + j) * 2; + const int x_c = (i * kernel_w + j) * 2 + 1; + offset_h = offset.channel(y_c / offset.elempack).row(h_col)[w_col * offset.elempack + y_c % offset.elempack]; + offset_w = offset.channel(x_c / offset.elempack).row(h_col)[w_col * offset.elempack + x_c % offset.elempack]; + } + if (has_mask) + { + const Mat& mask = bottom_blobs[2]; + if (mask_not_pack) + { + mask_ = mask.channel(i * kernel_w + j).row(h_col)[w_col]; + } + else + { + const int m_c = i * kernel_w + j; + mask_ = mask.channel(m_c / mask.elempack).row(h_col)[w_col * mask.elempack + m_c % mask.elempack]; + } + } + const float h_im = h_in + i * dilation_h + offset_h; + const float w_im = w_in + j * dilation_w + offset_w; + + // Bilinear + const bool cond = h_im > -1 && w_im > -1 && h_im < h && w_im < w; + float w1 = 0.f; + float w2 = 0.f; + float w3 = 0.f; + float w4 = 0.f; + bool v1_cond = false; + bool v2_cond = false; + bool v3_cond = false; + bool v4_cond = false; + int v1_pos = 0; + int v2_pos = 0; + int v3_pos = 0; + int v4_pos = 0; + if (cond) + { + int h_low = (int)floorf(h_im); + int w_low = (int)floorf(w_im); + int h_high = h_low + 1; + int w_high = w_low + 1; + + float lh = h_im - h_low; + float lw = w_im - w_low; + float hh = 1 - lh; + float hw = 1 - lw; + + v1_cond = (h_low >= 0 && w_low >= 0); + v2_cond = (h_low >= 0 && w_high <= w - 1); + v3_cond = (h_high <= h - 1 && w_low >= 0); + v4_cond = (h_high <= h - 1 && w_high <= w - 1); + if (v1_cond) + v1_pos = h_low * w + w_low; + if (v2_cond) + v2_pos = h_low * w + w_high; + if (v3_cond) + v3_pos = h_high * w + w_low; + if (v4_cond) + v4_pos = h_high * w + w_high; + + w1 = hh * hw; + w2 = hh * lw; + w3 = lh * hw; + w4 = lh * lw; + } + + for (int ic = 0; ic < inch; ic++) + { + const float* data_im_ptr = bottom_blob.channel(ic); + + if (cond) + { + vfloat32m1_t _v1 = v1_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v1_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); + vfloat32m1_t _v2 = v2_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v2_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); + vfloat32m1_t _v3 = v3_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v3_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); + vfloat32m1_t _v4 = v4_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v4_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); + + vfloat32m1_t _val = __riscv_vfmv_v_f_f32m1(0.f, vl); + _val = __riscv_vfmacc_vf_f32m1(_val, w1, _v1, vl); + _val = __riscv_vfmacc_vf_f32m1(_val, w2, _v2, vl); + _val = __riscv_vfmacc_vf_f32m1(_val, w3, _v3, vl); + _val = __riscv_vfmacc_vf_f32m1(_val, w4, _v4, vl); + + if (has_mask) + _val = __riscv_vfmul_vf_f32m1(_val, mask_, vl); + + vfloat32m1_t _w = __riscv_vle32_v_f32m1(kptr, vl); + _sum = __riscv_vfmacc_vv_f32m1(_sum, _val, _w, vl); + } + + kptr += packn; + } + } + } + + vfloat32m1_t _v_sum = __riscv_vfredusum_vs_f32m1_f32m1(_sum, __riscv_vfmv_v_f_f32m1(0.f, vl), vl); + sum += __riscv_vfmv_f_s_f32m1_f32(_v_sum); + + sum = activation_ss(sum, activation_type, activation_params); + outptr[h_col * outw + w_col] = sum; + } + } + } +} + + diff --git a/src/layer/riscv/deformableconv2d_riscv.cpp b/src/layer/riscv/deformableconv2d_riscv.cpp new file mode 100644 index 000000000000..d88eda0e4c7c --- /dev/null +++ b/src/layer/riscv/deformableconv2d_riscv.cpp @@ -0,0 +1,482 @@ + +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "deformableconv2d_riscv.h" + +#if __riscv_vector +#include +#endif // __riscv_vector + +#include "riscv_activation.h" +#include "riscv_usability.h" + +#include "benchmark.h" +#include "cpu.h" +#include "layer_type.h" + +namespace ncnn { + +#if __riscv_vector +#include "deformableconv2d_packn.h" +#include "deformableconv2d_pack1ton.h" +#include "deformableconv2d_packnto1.h" +#endif // __riscv_vector + +DeformableConv2D_riscv::DeformableConv2D_riscv() +{ +#if __riscv_vector + support_packing = true; +#endif // __riscv_vector + + activation = 0; + gemm = 0; +} + +static int _4Dindex_to_1Dindex(int i0, int i1, int i2, int i3, int l1, int l2, int l3) +{ + return ((i0 * l1 + i1) * l2 + i2) * l3 + i3; +} + +static int _6Dindex_to_1Dindex(int i0, int i1, int i2, int i3, int i4, int i5, int l1, int l2, int l3, int l4, int l5) +{ + return ((((i0 * l1 + i1) * l2 + i2) * l3 + i3) * l4 + i4) * l5 + i5; +} + +#if __riscv_vector +static void deformableconv2d_transform_kernel_packed_riscv(const Mat& weight_data, Mat& weight_data_tm, int num_input, int num_output, int kernel_w, int kernel_h, int elempack, int out_elempack) +{ + const int maxk = kernel_w * kernel_h; + + // src = kw-kh-inch-outch + // dst = pb-pa-inch/pa-kw-kh-outch/pb + { + const float* weight_ptr = weight_data; + + weight_data_tm.create(num_input * maxk * num_output / (elempack * out_elempack), (size_t)4u * elempack * out_elempack, elempack * out_elempack); + float* ptr = weight_data_tm; + for (int oc = 0; oc < num_output; oc++) + { + for (int i = 0; i < kernel_h; i++) + { + for (int j = 0; j < kernel_w; j++) + { + for (int ic = 0; ic < num_input; ic++) + { + ptr[_6Dindex_to_1Dindex(oc / out_elempack, i, j, ic / elempack, ic % elempack, oc % out_elempack, kernel_h, kernel_w, num_input / elempack, elempack, out_elempack)] = weight_ptr[_4Dindex_to_1Dindex(oc, ic, i, j, num_input, kernel_h, kernel_w)]; + } + } + } + } + weight_data_tm = weight_data_tm.reshape(num_input / elempack, maxk, num_output / out_elempack); + } +} +#endif // __riscv_vector + +int DeformableConv2D_riscv::create_pipeline(const Option& opt) +{ + activation = create_activation_layer(activation_type, activation_params, opt); + + int kernel_size = kernel_w * kernel_h; + int num_input = weight_data_size / kernel_size / num_output; + + int elempack = 1; + int out_elempack = 1; + +#if __riscv_vector + if (opt.use_packing_layout) + { + const int packn = csrr_vlenb() / 4; + elempack = num_input % packn == 0 ? packn : 1; + out_elempack = num_output % packn == 0 ? packn : 1; + } +#endif // __riscv_vector + + if (opt.use_sgemm_convolution) + { + const int maxk = kernel_w * kernel_h; + + gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); + + ncnn::ParamDict pd; + pd.set(2, 0); // transA + pd.set(3, 0); // transB + pd.set(4, 1); // constantA + pd.set(5, 0); // constantB + pd.set(6, 1); // constantC + pd.set(7, num_output); // M = outch + pd.set(8, 0); // N = size + pd.set(9, maxk * num_input); // K = maxk*inch + pd.set(10, bias_term ? 1 : -1); // constant_broadcast_type_C = (M) + pd.set(11, 1); // output_N1M + + gemm->load_param(pd); + + // maxk-inch-outch to pa-maxk-inch/pa-outch + Mat tmp; + { + Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); + + tmp.create(maxk * num_input, num_output); + + for (int q = 0; q < num_output; q += 1) + { + float* g00 = tmp.row(q); + + for (int p = 0; p + (elempack - 1) < num_input; p += elempack) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < elempack; i++) + { + const float* k00 = weight_data_r2.channel(q).row(p + i); + g00[0] = k00[k]; + g00++; + } + } + } + } + } + + if (bias_term) + { + ncnn::Mat weights[2]; + weights[0] = tmp; + weights[1] = bias_data; + + gemm->load_model(ModelBinFromMatArray(weights)); + } + else + { + ncnn::Mat weights[1]; + weights[0] = tmp; + + gemm->load_model(ModelBinFromMatArray(weights)); + } + + gemm->create_pipeline(opt); + } + else if (elempack == 1 && out_elempack == 1) + { + weight_data_tm = weight_data; + } + else + { +#if __riscv_vector + deformableconv2d_transform_kernel_packed_riscv(weight_data, weight_data_tm, num_input, num_output, kernel_w, kernel_h, elempack, out_elempack); +#endif // __riscv_vector + } + + if (opt.lightmode) + { + if (!(elempack == 1 && out_elempack == 1)) + weight_data.release(); + } + + return 0; +} + +int DeformableConv2D_riscv::destroy_pipeline(const Option& opt) +{ + if (activation) + { + activation->destroy_pipeline(opt); + delete activation; + activation = 0; + } + + if (gemm) + { + gemm->destroy_pipeline(opt); + delete gemm; + gemm = 0; + } + + return 0; +} + +int DeformableConv2D_riscv::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& bottom_blob = bottom_blobs[0]; + const Mat& offset = bottom_blobs[1]; + const bool has_mask = (bottom_blobs.size() == 3); + Mat& top_blob = top_blobs[0]; + + int w = bottom_blob.w; + int h = bottom_blob.h; + int channels = bottom_blob.c; + size_t elemsize = bottom_blob.elemsize; + int elempack = bottom_blob.elempack; + + const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; + const int kernel_extent_h = dilation_h * (kernel_h - 1) + 1; + const int outw = (w + pad_left + pad_right - kernel_extent_w) / stride_w + 1; + const int outh = (h + pad_top + pad_bottom - kernel_extent_h) / stride_h + 1; + + int out_elempack = 1; +#if __riscv_vector + if (opt.use_packing_layout) + { + const int packn = csrr_vlenb() / 4; + out_elempack = num_output % packn == 0 ? packn : 1; + } +#endif // __riscv_vector + size_t out_elemsize = elemsize / elempack * out_elempack; + + top_blob.create(outw, outh, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + if (opt.use_sgemm_convolution) + { + const int size = outw * outh; + const int maxk = kernel_w * kernel_h; + + Mat offset_unpacked; + convert_packing(offset, offset_unpacked, 1, opt); + + Mat mask_unpacked; + if (has_mask) + { + const Mat& mask = bottom_blobs[2]; + convert_packing(mask, mask_unpacked, 1, opt); + } + + // im2col + Mat bottom_im2col(size, maxk * channels, elemsize, elempack, opt.workspace_allocator); + +#if __riscv_vector + const int packn = csrr_vlenb() / 4; + if (elempack == packn) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < channels; p++) + { + const Mat img = bottom_blob.channel(p); + float* ptr = bottom_im2col.row(p * maxk); + + for (int u = 0; u < kernel_h; u++) + { + for (int v = 0; v < kernel_w; v++) + { + const Mat offset_h_k = offset_unpacked.channel((u * kernel_w + v) * 2); + const Mat offset_w_k = offset_unpacked.channel((u * kernel_w + v) * 2 + 1); + const Mat mask_k = has_mask ? mask_unpacked.channel(u * kernel_w + v) : 0; + + for (int i = 0; i < outh; i++) + { + for (int j = 0; j < outw; j++) + { + float offset_h = offset_h_k.row(i)[j]; + float offset_w = offset_w_k.row(i)[j]; + + int h_in = i * stride_h - pad_top; + int w_in = j * stride_w - pad_left; + + const float h_im = h_in + u * dilation_h + offset_h; + const float w_im = w_in + v * dilation_w + offset_w; + + // Bilinear + size_t vl = __riscv_vsetvl_e32m1(packn); + vfloat32m1_t _val = __riscv_vfmv_v_f_f32m1(0.f, vl); + bool cond = h_im > -1 && w_im > -1 && h_im < h && w_im < w; + if (cond) + { + int h_low = floor(h_im); + int w_low = floor(w_im); + int h_high = h_low + 1; + int w_high = w_low + 1; + + float lh = h_im - h_low; + float lw = w_im - w_low; + float hh = 1 - lh; + float hw = 1 - lw; + + bool v1_cond = (h_low >= 0 && w_low >= 0); + bool v2_cond = (h_low >= 0 && w_high <= w - 1); + bool v3_cond = (h_high <= h - 1 && w_low >= 0); + bool v4_cond = (h_high <= h - 1 && w_high <= w - 1); + + float w1 = hh * hw; + float w2 = hh * lw; + float w3 = lh * hw; + float w4 = lh * lw; + + vfloat32m1_t _v1 = v1_cond ? __riscv_vle32_v_f32m1((const float*)img.row(h_low) + w_low * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); + vfloat32m1_t _v2 = v2_cond ? __riscv_vle32_v_f32m1((const float*)img.row(h_low) + w_high * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); + vfloat32m1_t _v3 = v3_cond ? __riscv_vle32_v_f32m1((const float*)img.row(h_high) + w_low * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); + vfloat32m1_t _v4 = v4_cond ? __riscv_vle32_v_f32m1((const float*)img.row(h_high) + w_high * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); + + _val = __riscv_vfmacc_vf_f32m1(_val, w1, _v1, vl); + _val = __riscv_vfmacc_vf_f32m1(_val, w2, _v2, vl); + _val = __riscv_vfmacc_vf_f32m1(_val, w3, _v3, vl); + _val = __riscv_vfmacc_vf_f32m1(_val, w4, _v4, vl); + + if (has_mask) + _val = __riscv_vfmul_vf_f32m1(_val, mask_k.row(i)[j], vl); + } + + __riscv_vse32_v_f32m1(ptr, _val, vl); + + ptr += packn; + } + } + } + } + } + } +#endif // __riscv_vector + + if (elempack == 1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < channels; p++) + { + const Mat img = bottom_blob.channel(p); + float* ptr = bottom_im2col.row(p * maxk); + + for (int u = 0; u < kernel_h; u++) + { + for (int v = 0; v < kernel_w; v++) + { + const Mat offset_h_k = offset_unpacked.channel((u * kernel_w + v) * 2); + const Mat offset_w_k = offset_unpacked.channel((u * kernel_w + v) * 2 + 1); + const Mat mask_k = has_mask ? mask_unpacked.channel(u * kernel_w + v) : 0; + + for (int i = 0; i < outh; i++) + { + for (int j = 0; j < outw; j++) + { + float offset_h = offset_h_k.row(i)[j]; + float offset_w = offset_w_k.row(i)[j]; + + int h_in = i * stride_h - pad_top; + int w_in = j * stride_w - pad_left; + + const float h_im = h_in + u * dilation_h + offset_h; + const float w_im = w_in + v * dilation_w + offset_w; + + // Bilinear + float val = 0.f; + bool cond = h_im > -1 && w_im > -1 && h_im < h && w_im < w; + if (cond) + { + int h_low = (int)floorf(h_im); + int w_low = (int)floorf(w_im); + int h_high = h_low + 1; + int w_high = w_low + 1; + + float lh = h_im - h_low; + float lw = w_im - w_low; + float hh = 1 - lh; + float hw = 1 - lw; + + bool v1_cond = (h_low >= 0 && w_low >= 0); + bool v2_cond = (h_low >= 0 && w_high <= w - 1); + bool v3_cond = (h_high <= h - 1 && w_low >= 0); + bool v4_cond = (h_high <= h - 1 && w_high <= w - 1); + + float w1 = hh * hw; + float w2 = hh * lw; + float w3 = lh * hw; + float w4 = lh * lw; + + float v1 = v1_cond ? img.row(h_low)[w_low] : 0.f; + float v2 = v2_cond ? img.row(h_low)[w_high] : 0.f; + float v3 = v3_cond ? img.row(h_high)[w_low] : 0.f; + float v4 = v4_cond ? img.row(h_high)[w_high] : 0.f; + val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; + + if (has_mask) + val *= mask_k.row(i)[j]; + } + + ptr[0] = val; + + ptr += 1; + } + } + } + } + } + } + + // sgemm + { + top_blob.w = outw * outh; + top_blob.h = 1; + } + Option opt_b = opt; + opt_b.blob_allocator = opt.workspace_allocator; + gemm->forward(bottom_im2col, top_blob, opt_b); + { + top_blob.w = outw; + top_blob.h = outh; + } + + if (activation) + { + activation->forward_inplace(top_blob, opt); + } + } + else + { +#if __riscv_vector + const int packn = csrr_vlenb() / 4; + + if (elempack == packn && out_elempack == packn) + { + deformableconv2d_packn(bottom_blobs, top_blob, weight_data_tm, bias_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, pad_left, pad_top, activation_type, activation_params, opt); + } + + if (elempack == 1 && out_elempack == packn) + { + deformableconv2d_pack1ton(bottom_blobs, top_blob, weight_data_tm, bias_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, pad_left, pad_top, activation_type, activation_params, opt); + } + + if (elempack == packn && out_elempack == 1) + { + deformableconv2d_packnto1(bottom_blobs, top_blob, weight_data_tm, bias_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, pad_left, pad_top, activation_type, activation_params, opt); + } + + if (elempack == 1 && out_elempack == 1) + { + std::vector bottom_blobs_unpacked = bottom_blobs; + Mat offset_unpacked; + if (offset.elempack != 1) + { + convert_packing(offset, offset_unpacked, 1, opt); + bottom_blobs_unpacked[1] = offset_unpacked; + } + + if (bottom_blobs.size() == 3) + { + const Mat& mask = bottom_blobs[2]; + if (mask.elempack != 1) + { + Mat mask_unpacked; + convert_packing(mask, mask_unpacked, 1, opt); + bottom_blobs_unpacked[2] = mask_unpacked; + } + } + + return DeformableConv2D::forward(bottom_blobs_unpacked, top_blobs, opt); + } +#endif // __riscv_vector + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/riscv/deformableconv2d_riscv.h b/src/layer/riscv/deformableconv2d_riscv.h new file mode 100644 index 000000000000..5d6a3b7765be --- /dev/null +++ b/src/layer/riscv/deformableconv2d_riscv.h @@ -0,0 +1,43 @@ + +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef LAYER_DEFORMABLECONV2D_RISCV_H +#define LAYER_DEFORMABLECONV2D_RISCV_H + +#include "deformableconv2d.h" + +namespace ncnn { + +class DeformableConv2D_riscv : public DeformableConv2D +{ +public: + DeformableConv2D_riscv(); + + virtual int create_pipeline(const Option& opt); + virtual int destroy_pipeline(const Option& opt); + + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +public: + Layer* activation; + + Mat weight_data_tm; + + Layer* gemm; +}; + +} // namespace ncnn + +#endif // LAYER_DEFORMABLECONV2D_RISCV_H From 2934858505513cf601a43ca3568793802876be9e Mon Sep 17 00:00:00 2001 From: chenglimin <18213449+chenglimin@users.noreply.github.com> Date: Tue, 10 Feb 2026 08:54:17 +0000 Subject: [PATCH 02/14] apply code-format changes --- src/layer/riscv/deformableconv2d_pack1ton.h | 11 +++++------ src/layer/riscv/deformableconv2d_packn.h | 19 +++++++++---------- src/layer/riscv/deformableconv2d_packnto1.h | 21 +++++++++------------ 3 files changed, 23 insertions(+), 28 deletions(-) diff --git a/src/layer/riscv/deformableconv2d_pack1ton.h b/src/layer/riscv/deformableconv2d_pack1ton.h index dfa046c15e57..628a9f7ae470 100644 --- a/src/layer/riscv/deformableconv2d_pack1ton.h +++ b/src/layer/riscv/deformableconv2d_pack1ton.h @@ -48,7 +48,7 @@ static void deformableconv2d_pack1ton(const std::vector& bottom_blobs, Mat& vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); if (bias_data_ptr) _sum = __riscv_vle32_v_f32m1(bias_data_ptr + oc * packn, vl); - + for (int i = 0; i < kernel_h; i++) { for (int j = 0; j < kernel_w; j++) @@ -132,7 +132,7 @@ static void deformableconv2d_pack1ton(const std::vector& bottom_blobs, Mat& for (int ic = 0; ic < inch; ic++) { const float* data_im_ptr = bottom_blob.channel(ic); - + if (cond) { float v_in = 0.f; @@ -140,13 +140,13 @@ static void deformableconv2d_pack1ton(const std::vector& bottom_blobs, Mat& if (v2_cond) v_in += data_im_ptr[v2_pos] * w2; if (v3_cond) v_in += data_im_ptr[v3_pos] * w3; if (v4_cond) v_in += data_im_ptr[v4_pos] * w4; - + if (has_mask) v_in *= mask_; - + vfloat32m1_t _w = __riscv_vle32_v_f32m1(kptr, vl); _sum = __riscv_vfmacc_vf_f32m1(_sum, v_in, _w, vl); } - + kptr += packn; } } @@ -157,4 +157,3 @@ static void deformableconv2d_pack1ton(const std::vector& bottom_blobs, Mat& } } } - diff --git a/src/layer/riscv/deformableconv2d_packn.h b/src/layer/riscv/deformableconv2d_packn.h index c44e8bacc19e..50346e6816bd 100644 --- a/src/layer/riscv/deformableconv2d_packn.h +++ b/src/layer/riscv/deformableconv2d_packn.h @@ -48,7 +48,7 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); if (bias_data_ptr) _sum = __riscv_vle32_v_f32m1(bias_data_ptr + oc * packn, vl); - + for (int i = 0; i < kernel_h; i++) { for (int j = 0; j < kernel_w; j++) @@ -132,11 +132,11 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to for (int ic = 0; ic < inch; ic++) { const float* data_im_ptr = bottom_blob.channel(ic); - + if (cond) { vfloat32m1_t _val = __riscv_vfmv_v_f_f32m1(0.f, vl); - + // Since we are iterating over input channels which are packed, // we need to handle each element in the pack. // However, the weight layout for packn is: @@ -144,16 +144,16 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to // Wait, let's check the weight transformation in deformableconv2d_riscv.cpp // weight_data_tm.create(num_input * maxk * num_output / (elempack * out_elempack), (size_t)4u * elempack * out_elempack, elempack * out_elempack); // It seems the weight is packed as [packn_in * packn_out] - + // For each input channel pack (size packn), we have packn input values. // Each input value contributes to all packn output values. // So we have packn * packn weights for this block. - + // Let's look at x86 implementation again. // _val_channel0..3 corresponds to the 4 input values in the pack. // _conv_w0..3 corresponds to the weights for these input values. // Each _conv_w is a vector of size 4 (out_elempack), representing weights for one input channel to all 4 output channels. - + for (int k = 0; k < packn; k++) { float v_in = 0.f; @@ -161,14 +161,14 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to if (v2_cond) v_in += data_im_ptr[v2_pos * packn + k] * w2; if (v3_cond) v_in += data_im_ptr[v3_pos * packn + k] * w3; if (v4_cond) v_in += data_im_ptr[v4_pos * packn + k] * w4; - + if (has_mask) v_in *= mask_; - + vfloat32m1_t _w = __riscv_vle32_v_f32m1(kptr + k * packn, vl); _sum = __riscv_vfmacc_vf_f32m1(_sum, v_in, _w, vl); } } - + kptr += packn * packn; } } @@ -179,4 +179,3 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to } } } - diff --git a/src/layer/riscv/deformableconv2d_packnto1.h b/src/layer/riscv/deformableconv2d_packnto1.h index 9c04b28c35ca..2f5896c04f79 100644 --- a/src/layer/riscv/deformableconv2d_packnto1.h +++ b/src/layer/riscv/deformableconv2d_packnto1.h @@ -1,6 +1,5 @@ - // Tencent is pleased to support the open source community by making ncnn available. // // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. @@ -49,9 +48,9 @@ static void deformableconv2d_packnto1(const std::vector& bottom_blobs, Mat& float sum = 0.f; if (bias_data_ptr) sum = bias_data_ptr[oc]; - + vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); - + for (int i = 0; i < kernel_h; i++) { for (int j = 0; j < kernel_w; j++) @@ -135,40 +134,38 @@ static void deformableconv2d_packnto1(const std::vector& bottom_blobs, Mat& for (int ic = 0; ic < inch; ic++) { const float* data_im_ptr = bottom_blob.channel(ic); - + if (cond) { vfloat32m1_t _v1 = v1_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v1_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); vfloat32m1_t _v2 = v2_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v2_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); vfloat32m1_t _v3 = v3_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v3_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); vfloat32m1_t _v4 = v4_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v4_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); - + vfloat32m1_t _val = __riscv_vfmv_v_f_f32m1(0.f, vl); _val = __riscv_vfmacc_vf_f32m1(_val, w1, _v1, vl); _val = __riscv_vfmacc_vf_f32m1(_val, w2, _v2, vl); _val = __riscv_vfmacc_vf_f32m1(_val, w3, _v3, vl); _val = __riscv_vfmacc_vf_f32m1(_val, w4, _v4, vl); - + if (has_mask) _val = __riscv_vfmul_vf_f32m1(_val, mask_, vl); - + vfloat32m1_t _w = __riscv_vle32_v_f32m1(kptr, vl); _sum = __riscv_vfmacc_vv_f32m1(_sum, _val, _w, vl); } - + kptr += packn; } } } - + vfloat32m1_t _v_sum = __riscv_vfredusum_vs_f32m1_f32m1(_sum, __riscv_vfmv_v_f_f32m1(0.f, vl), vl); sum += __riscv_vfmv_f_s_f32m1_f32(_v_sum); - + sum = activation_ss(sum, activation_type, activation_params); outptr[h_col * outw + w_col] = sum; } } } } - - From 8ceffbed9d2a5804eabd71b959a0be215fdeb377 Mon Sep 17 00:00:00 2001 From: chenglimin Date: Tue, 10 Feb 2026 18:23:21 +0800 Subject: [PATCH 03/14] Fix deformableconv2d RVV implementation --- src/layer/riscv/deformableconv2d_pack1ton.h | 28 +++++----------- src/layer/riscv/deformableconv2d_packn.h | 36 +++++++------------- src/layer/riscv/deformableconv2d_packnto1.h | 37 ++++++++------------- src/layer/riscv/deformableconv2d_riscv.cpp | 16 ++------- src/layer/riscv/deformableconv2d_riscv.h | 16 ++------- 5 files changed, 37 insertions(+), 96 deletions(-) diff --git a/src/layer/riscv/deformableconv2d_pack1ton.h b/src/layer/riscv/deformableconv2d_pack1ton.h index 628a9f7ae470..945d2b5e39c0 100644 --- a/src/layer/riscv/deformableconv2d_pack1ton.h +++ b/src/layer/riscv/deformableconv2d_pack1ton.h @@ -1,18 +1,5 @@ - - -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause static void deformableconv2d_pack1ton(const std::vector& bottom_blobs, Mat& top_blob, const Mat& weight_data_packed, const Mat& bias_data, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int pad_left, int pad_top, int activation_type, const Mat& activation_params, const Option& opt) { @@ -48,7 +35,7 @@ static void deformableconv2d_pack1ton(const std::vector& bottom_blobs, Mat& vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); if (bias_data_ptr) _sum = __riscv_vle32_v_f32m1(bias_data_ptr + oc * packn, vl); - + for (int i = 0; i < kernel_h; i++) { for (int j = 0; j < kernel_w; j++) @@ -132,7 +119,7 @@ static void deformableconv2d_pack1ton(const std::vector& bottom_blobs, Mat& for (int ic = 0; ic < inch; ic++) { const float* data_im_ptr = bottom_blob.channel(ic); - + if (cond) { float v_in = 0.f; @@ -140,13 +127,13 @@ static void deformableconv2d_pack1ton(const std::vector& bottom_blobs, Mat& if (v2_cond) v_in += data_im_ptr[v2_pos] * w2; if (v3_cond) v_in += data_im_ptr[v3_pos] * w3; if (v4_cond) v_in += data_im_ptr[v4_pos] * w4; - + if (has_mask) v_in *= mask_; - + vfloat32m1_t _w = __riscv_vle32_v_f32m1(kptr, vl); _sum = __riscv_vfmacc_vf_f32m1(_sum, v_in, _w, vl); } - + kptr += packn; } } @@ -157,3 +144,4 @@ static void deformableconv2d_pack1ton(const std::vector& bottom_blobs, Mat& } } } + diff --git a/src/layer/riscv/deformableconv2d_packn.h b/src/layer/riscv/deformableconv2d_packn.h index 50346e6816bd..61612a260aea 100644 --- a/src/layer/riscv/deformableconv2d_packn.h +++ b/src/layer/riscv/deformableconv2d_packn.h @@ -1,18 +1,5 @@ - - -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& top_blob, const Mat& weight_data_packed, const Mat& bias_data, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int pad_left, int pad_top, int activation_type, const Mat& activation_params, const Option& opt) { @@ -48,7 +35,7 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); if (bias_data_ptr) _sum = __riscv_vle32_v_f32m1(bias_data_ptr + oc * packn, vl); - + for (int i = 0; i < kernel_h; i++) { for (int j = 0; j < kernel_w; j++) @@ -132,11 +119,11 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to for (int ic = 0; ic < inch; ic++) { const float* data_im_ptr = bottom_blob.channel(ic); - + if (cond) { vfloat32m1_t _val = __riscv_vfmv_v_f_f32m1(0.f, vl); - + // Since we are iterating over input channels which are packed, // we need to handle each element in the pack. // However, the weight layout for packn is: @@ -144,16 +131,16 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to // Wait, let's check the weight transformation in deformableconv2d_riscv.cpp // weight_data_tm.create(num_input * maxk * num_output / (elempack * out_elempack), (size_t)4u * elempack * out_elempack, elempack * out_elempack); // It seems the weight is packed as [packn_in * packn_out] - + // For each input channel pack (size packn), we have packn input values. // Each input value contributes to all packn output values. // So we have packn * packn weights for this block. - + // Let's look at x86 implementation again. // _val_channel0..3 corresponds to the 4 input values in the pack. // _conv_w0..3 corresponds to the weights for these input values. // Each _conv_w is a vector of size 4 (out_elempack), representing weights for one input channel to all 4 output channels. - + for (int k = 0; k < packn; k++) { float v_in = 0.f; @@ -161,14 +148,14 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to if (v2_cond) v_in += data_im_ptr[v2_pos * packn + k] * w2; if (v3_cond) v_in += data_im_ptr[v3_pos * packn + k] * w3; if (v4_cond) v_in += data_im_ptr[v4_pos * packn + k] * w4; - + if (has_mask) v_in *= mask_; - + vfloat32m1_t _w = __riscv_vle32_v_f32m1(kptr + k * packn, vl); _sum = __riscv_vfmacc_vf_f32m1(_sum, v_in, _w, vl); } } - + kptr += packn * packn; } } @@ -179,3 +166,4 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to } } } + diff --git a/src/layer/riscv/deformableconv2d_packnto1.h b/src/layer/riscv/deformableconv2d_packnto1.h index 2f5896c04f79..577058ea1d41 100644 --- a/src/layer/riscv/deformableconv2d_packnto1.h +++ b/src/layer/riscv/deformableconv2d_packnto1.h @@ -1,18 +1,5 @@ - - -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause static void deformableconv2d_packnto1(const std::vector& bottom_blobs, Mat& top_blob, const Mat& weight_data_packed, const Mat& bias_data, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int pad_left, int pad_top, int activation_type, const Mat& activation_params, const Option& opt) { @@ -48,9 +35,9 @@ static void deformableconv2d_packnto1(const std::vector& bottom_blobs, Mat& float sum = 0.f; if (bias_data_ptr) sum = bias_data_ptr[oc]; - + vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); - + for (int i = 0; i < kernel_h; i++) { for (int j = 0; j < kernel_w; j++) @@ -134,38 +121,40 @@ static void deformableconv2d_packnto1(const std::vector& bottom_blobs, Mat& for (int ic = 0; ic < inch; ic++) { const float* data_im_ptr = bottom_blob.channel(ic); - + if (cond) { vfloat32m1_t _v1 = v1_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v1_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); vfloat32m1_t _v2 = v2_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v2_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); vfloat32m1_t _v3 = v3_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v3_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); vfloat32m1_t _v4 = v4_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v4_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); - + vfloat32m1_t _val = __riscv_vfmv_v_f_f32m1(0.f, vl); _val = __riscv_vfmacc_vf_f32m1(_val, w1, _v1, vl); _val = __riscv_vfmacc_vf_f32m1(_val, w2, _v2, vl); _val = __riscv_vfmacc_vf_f32m1(_val, w3, _v3, vl); _val = __riscv_vfmacc_vf_f32m1(_val, w4, _v4, vl); - + if (has_mask) _val = __riscv_vfmul_vf_f32m1(_val, mask_, vl); - + vfloat32m1_t _w = __riscv_vle32_v_f32m1(kptr, vl); _sum = __riscv_vfmacc_vv_f32m1(_sum, _val, _w, vl); } - + kptr += packn; } } } - + vfloat32m1_t _v_sum = __riscv_vfredusum_vs_f32m1_f32m1(_sum, __riscv_vfmv_v_f_f32m1(0.f, vl), vl); sum += __riscv_vfmv_f_s_f32m1_f32(_v_sum); - + sum = activation_ss(sum, activation_type, activation_params); outptr[h_col * outw + w_col] = sum; } } } } + + diff --git a/src/layer/riscv/deformableconv2d_riscv.cpp b/src/layer/riscv/deformableconv2d_riscv.cpp index d88eda0e4c7c..e0975e1e6054 100644 --- a/src/layer/riscv/deformableconv2d_riscv.cpp +++ b/src/layer/riscv/deformableconv2d_riscv.cpp @@ -1,17 +1,5 @@ - -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause #include "deformableconv2d_riscv.h" diff --git a/src/layer/riscv/deformableconv2d_riscv.h b/src/layer/riscv/deformableconv2d_riscv.h index 5d6a3b7765be..d538e5097075 100644 --- a/src/layer/riscv/deformableconv2d_riscv.h +++ b/src/layer/riscv/deformableconv2d_riscv.h @@ -1,17 +1,5 @@ - -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause #ifndef LAYER_DEFORMABLECONV2D_RISCV_H #define LAYER_DEFORMABLECONV2D_RISCV_H From f8dd5ab068c17121f8f8e93817c505a48c84a153 Mon Sep 17 00:00:00 2001 From: chenglimin <18213449+chenglimin@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:27:31 +0000 Subject: [PATCH 04/14] apply code-format changes --- src/layer/riscv/deformableconv2d_pack1ton.h | 11 +++++------ src/layer/riscv/deformableconv2d_packn.h | 19 +++++++++---------- src/layer/riscv/deformableconv2d_packnto1.h | 20 +++++++++----------- 3 files changed, 23 insertions(+), 27 deletions(-) diff --git a/src/layer/riscv/deformableconv2d_pack1ton.h b/src/layer/riscv/deformableconv2d_pack1ton.h index 945d2b5e39c0..e5ba658c9551 100644 --- a/src/layer/riscv/deformableconv2d_pack1ton.h +++ b/src/layer/riscv/deformableconv2d_pack1ton.h @@ -35,7 +35,7 @@ static void deformableconv2d_pack1ton(const std::vector& bottom_blobs, Mat& vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); if (bias_data_ptr) _sum = __riscv_vle32_v_f32m1(bias_data_ptr + oc * packn, vl); - + for (int i = 0; i < kernel_h; i++) { for (int j = 0; j < kernel_w; j++) @@ -119,7 +119,7 @@ static void deformableconv2d_pack1ton(const std::vector& bottom_blobs, Mat& for (int ic = 0; ic < inch; ic++) { const float* data_im_ptr = bottom_blob.channel(ic); - + if (cond) { float v_in = 0.f; @@ -127,13 +127,13 @@ static void deformableconv2d_pack1ton(const std::vector& bottom_blobs, Mat& if (v2_cond) v_in += data_im_ptr[v2_pos] * w2; if (v3_cond) v_in += data_im_ptr[v3_pos] * w3; if (v4_cond) v_in += data_im_ptr[v4_pos] * w4; - + if (has_mask) v_in *= mask_; - + vfloat32m1_t _w = __riscv_vle32_v_f32m1(kptr, vl); _sum = __riscv_vfmacc_vf_f32m1(_sum, v_in, _w, vl); } - + kptr += packn; } } @@ -144,4 +144,3 @@ static void deformableconv2d_pack1ton(const std::vector& bottom_blobs, Mat& } } } - diff --git a/src/layer/riscv/deformableconv2d_packn.h b/src/layer/riscv/deformableconv2d_packn.h index 61612a260aea..4b2b0e7f194e 100644 --- a/src/layer/riscv/deformableconv2d_packn.h +++ b/src/layer/riscv/deformableconv2d_packn.h @@ -35,7 +35,7 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); if (bias_data_ptr) _sum = __riscv_vle32_v_f32m1(bias_data_ptr + oc * packn, vl); - + for (int i = 0; i < kernel_h; i++) { for (int j = 0; j < kernel_w; j++) @@ -119,11 +119,11 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to for (int ic = 0; ic < inch; ic++) { const float* data_im_ptr = bottom_blob.channel(ic); - + if (cond) { vfloat32m1_t _val = __riscv_vfmv_v_f_f32m1(0.f, vl); - + // Since we are iterating over input channels which are packed, // we need to handle each element in the pack. // However, the weight layout for packn is: @@ -131,16 +131,16 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to // Wait, let's check the weight transformation in deformableconv2d_riscv.cpp // weight_data_tm.create(num_input * maxk * num_output / (elempack * out_elempack), (size_t)4u * elempack * out_elempack, elempack * out_elempack); // It seems the weight is packed as [packn_in * packn_out] - + // For each input channel pack (size packn), we have packn input values. // Each input value contributes to all packn output values. // So we have packn * packn weights for this block. - + // Let's look at x86 implementation again. // _val_channel0..3 corresponds to the 4 input values in the pack. // _conv_w0..3 corresponds to the weights for these input values. // Each _conv_w is a vector of size 4 (out_elempack), representing weights for one input channel to all 4 output channels. - + for (int k = 0; k < packn; k++) { float v_in = 0.f; @@ -148,14 +148,14 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to if (v2_cond) v_in += data_im_ptr[v2_pos * packn + k] * w2; if (v3_cond) v_in += data_im_ptr[v3_pos * packn + k] * w3; if (v4_cond) v_in += data_im_ptr[v4_pos * packn + k] * w4; - + if (has_mask) v_in *= mask_; - + vfloat32m1_t _w = __riscv_vle32_v_f32m1(kptr + k * packn, vl); _sum = __riscv_vfmacc_vf_f32m1(_sum, v_in, _w, vl); } } - + kptr += packn * packn; } } @@ -166,4 +166,3 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to } } } - diff --git a/src/layer/riscv/deformableconv2d_packnto1.h b/src/layer/riscv/deformableconv2d_packnto1.h index 577058ea1d41..d84ccd0a77c8 100644 --- a/src/layer/riscv/deformableconv2d_packnto1.h +++ b/src/layer/riscv/deformableconv2d_packnto1.h @@ -35,9 +35,9 @@ static void deformableconv2d_packnto1(const std::vector& bottom_blobs, Mat& float sum = 0.f; if (bias_data_ptr) sum = bias_data_ptr[oc]; - + vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); - + for (int i = 0; i < kernel_h; i++) { for (int j = 0; j < kernel_w; j++) @@ -121,40 +121,38 @@ static void deformableconv2d_packnto1(const std::vector& bottom_blobs, Mat& for (int ic = 0; ic < inch; ic++) { const float* data_im_ptr = bottom_blob.channel(ic); - + if (cond) { vfloat32m1_t _v1 = v1_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v1_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); vfloat32m1_t _v2 = v2_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v2_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); vfloat32m1_t _v3 = v3_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v3_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); vfloat32m1_t _v4 = v4_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v4_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); - + vfloat32m1_t _val = __riscv_vfmv_v_f_f32m1(0.f, vl); _val = __riscv_vfmacc_vf_f32m1(_val, w1, _v1, vl); _val = __riscv_vfmacc_vf_f32m1(_val, w2, _v2, vl); _val = __riscv_vfmacc_vf_f32m1(_val, w3, _v3, vl); _val = __riscv_vfmacc_vf_f32m1(_val, w4, _v4, vl); - + if (has_mask) _val = __riscv_vfmul_vf_f32m1(_val, mask_, vl); - + vfloat32m1_t _w = __riscv_vle32_v_f32m1(kptr, vl); _sum = __riscv_vfmacc_vv_f32m1(_sum, _val, _w, vl); } - + kptr += packn; } } } - + vfloat32m1_t _v_sum = __riscv_vfredusum_vs_f32m1_f32m1(_sum, __riscv_vfmv_v_f_f32m1(0.f, vl), vl); sum += __riscv_vfmv_f_s_f32m1_f32(_v_sum); - + sum = activation_ss(sum, activation_type, activation_params); outptr[h_col * outw + w_col] = sum; } } } } - - From d94348c1af85e7098f0d7bd349ef90547a83f0f6 Mon Sep 17 00:00:00 2001 From: chenglimin Date: Thu, 12 Feb 2026 15:50:47 +0800 Subject: [PATCH 05/14] Update src/layer/riscv/deformableconv2d_packn.h Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/layer/riscv/deformableconv2d_packn.h | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/src/layer/riscv/deformableconv2d_packn.h b/src/layer/riscv/deformableconv2d_packn.h index 4b2b0e7f194e..3a03e9b0b552 100644 --- a/src/layer/riscv/deformableconv2d_packn.h +++ b/src/layer/riscv/deformableconv2d_packn.h @@ -124,23 +124,13 @@ static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& to { vfloat32m1_t _val = __riscv_vfmv_v_f_f32m1(0.f, vl); - // Since we are iterating over input channels which are packed, - // we need to handle each element in the pack. - // However, the weight layout for packn is: - // [outch/packn][kh][kw][inch/packn][packn_in][packn_out] - // Wait, let's check the weight transformation in deformableconv2d_riscv.cpp - // weight_data_tm.create(num_input * maxk * num_output / (elempack * out_elempack), (size_t)4u * elempack * out_elempack, elempack * out_elempack); - // It seems the weight is packed as [packn_in * packn_out] - - // For each input channel pack (size packn), we have packn input values. - // Each input value contributes to all packn output values. - // So we have packn * packn weights for this block. - - // Let's look at x86 implementation again. - // _val_channel0..3 corresponds to the 4 input values in the pack. - // _conv_w0..3 corresponds to the weights for these input values. - // Each _conv_w is a vector of size 4 (out_elempack), representing weights for one input channel to all 4 output channels. - + // Packed-weight memory layout for packn: + // For each output-channel pack, kernel position (kh, kw) and input-channel pack, + // the weights are stored as a contiguous block of size packn_in * packn_out + // (with packn_in == packn_out == packn here). Within this block, lane k in + // the input pack uses the vector loaded from kptr + k * packn, which contains + // the weights from that input lane to all packn output channels. After all + // packn input lanes are processed, kptr is advanced by packn * packn. for (int k = 0; k < packn; k++) { float v_in = 0.f; From 09bbbc98df5b4eb63f5db386ae194b1122e76b8c Mon Sep 17 00:00:00 2001 From: chenglimin Date: Thu, 12 Feb 2026 16:59:18 +0800 Subject: [PATCH 06/14] add scaler fallback --- src/layer/riscv/deformableconv2d_riscv.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layer/riscv/deformableconv2d_riscv.cpp b/src/layer/riscv/deformableconv2d_riscv.cpp index e0975e1e6054..86e94b74c58e 100644 --- a/src/layer/riscv/deformableconv2d_riscv.cpp +++ b/src/layer/riscv/deformableconv2d_riscv.cpp @@ -437,7 +437,7 @@ int DeformableConv2D_riscv::forward(const std::vector& bottom_blobs, std::v { deformableconv2d_packnto1(bottom_blobs, top_blob, weight_data_tm, bias_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, pad_left, pad_top, activation_type, activation_params, opt); } - +#endif // __riscv_vector if (elempack == 1 && out_elempack == 1) { std::vector bottom_blobs_unpacked = bottom_blobs; @@ -461,7 +461,7 @@ int DeformableConv2D_riscv::forward(const std::vector& bottom_blobs, std::v return DeformableConv2D::forward(bottom_blobs_unpacked, top_blobs, opt); } -#endif // __riscv_vector + } return 0; From feff32e4574ec6d8e52c05de467b2ba42839866b Mon Sep 17 00:00:00 2001 From: chenglimin <18213449+chenglimin@users.noreply.github.com> Date: Thu, 12 Feb 2026 09:05:37 +0000 Subject: [PATCH 07/14] apply code-format changes --- src/layer/riscv/deformableconv2d_riscv.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/layer/riscv/deformableconv2d_riscv.cpp b/src/layer/riscv/deformableconv2d_riscv.cpp index 86e94b74c58e..8e6f7aa2b697 100644 --- a/src/layer/riscv/deformableconv2d_riscv.cpp +++ b/src/layer/riscv/deformableconv2d_riscv.cpp @@ -461,7 +461,6 @@ int DeformableConv2D_riscv::forward(const std::vector& bottom_blobs, std::v return DeformableConv2D::forward(bottom_blobs_unpacked, top_blobs, opt); } - } return 0; From 3380ab2053caa3322f9ea77a41d8459a4b05d638 Mon Sep 17 00:00:00 2001 From: chenglimin Date: Thu, 12 Feb 2026 18:49:10 +0800 Subject: [PATCH 08/14] add always release weight_data --- src/layer/riscv/deformableconv2d_riscv.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/layer/riscv/deformableconv2d_riscv.cpp b/src/layer/riscv/deformableconv2d_riscv.cpp index 8e6f7aa2b697..3355238e72a1 100644 --- a/src/layer/riscv/deformableconv2d_riscv.cpp +++ b/src/layer/riscv/deformableconv2d_riscv.cpp @@ -168,8 +168,7 @@ int DeformableConv2D_riscv::create_pipeline(const Option& opt) if (opt.lightmode) { - if (!(elempack == 1 && out_elempack == 1)) - weight_data.release(); + weight_data.release(); } return 0; From f0def8687ec09a3becbba397060562debd63a458 Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 20 Feb 2026 21:32:52 +0800 Subject: [PATCH 09/14] add pack1 path --- src/layer/riscv/deformableconv2d_riscv.cpp | 120 ++++++++++++++++++--- 1 file changed, 104 insertions(+), 16 deletions(-) diff --git a/src/layer/riscv/deformableconv2d_riscv.cpp b/src/layer/riscv/deformableconv2d_riscv.cpp index 3355238e72a1..a0695abe86b3 100644 --- a/src/layer/riscv/deformableconv2d_riscv.cpp +++ b/src/layer/riscv/deformableconv2d_riscv.cpp @@ -439,26 +439,114 @@ int DeformableConv2D_riscv::forward(const std::vector& bottom_blobs, std::v #endif // __riscv_vector if (elempack == 1 && out_elempack == 1) { - std::vector bottom_blobs_unpacked = bottom_blobs; - Mat offset_unpacked; - if (offset.elempack != 1) - { - convert_packing(offset, offset_unpacked, 1, opt); - bottom_blobs_unpacked[1] = offset_unpacked; - } - - if (bottom_blobs.size() == 3) + const bool offset_not_pack = offset.elempack == 1; + const bool mask_not_pack = has_mask ? bottom_blobs[2].elempack == 1 : true; + const float* weight_ptr = weight_data_tm; + + // naive deformable conv + #pragma omp parallel for num_threads(opt.num_threads) + for (int h_col = 0; h_col < outh; h_col++) { - const Mat& mask = bottom_blobs[2]; - if (mask.elempack != 1) + for (int w_col = 0; w_col < outw; w_col++) { - Mat mask_unpacked; - convert_packing(mask, mask_unpacked, 1, opt); - bottom_blobs_unpacked[2] = mask_unpacked; + int h_in = h_col * stride_h - pad_top; + int w_in = w_col * stride_w - pad_left; + for (int oc = 0; oc < num_output; oc++) + { + float sum = 0.f; + if (bias_term) + sum = bias_data[oc]; + for (int i = 0; i < kernel_h; i++) + { + for (int j = 0; j < kernel_w; j++) + { + float offset_h = 0.f; + float offset_w = 0.f; + float mask_ = 1.f; + if (offset_not_pack) + { + offset_h = offset.channel((i * kernel_w + j) * 2).row(h_col)[w_col]; + offset_w = offset.channel((i * kernel_w + j) * 2 + 1).row(h_col)[w_col]; + } + else + { + const int y_c = (i * kernel_w + j) * 2; + const int x_c = (i * kernel_w + j) * 2 + 1; + offset_h = offset.channel(y_c / offset.elempack).row(h_col)[w_col * offset.elempack + y_c % offset.elempack]; + offset_w = offset.channel(x_c / offset.elempack).row(h_col)[w_col * offset.elempack + x_c % offset.elempack]; + } + if (has_mask) + { + const Mat& mask = bottom_blobs[2]; + if (mask_not_pack) + { + mask_ = mask.channel(i * kernel_w + j).row(h_col)[w_col]; + } + else + { + const int m_c = i * kernel_w + j; + mask_ = mask.channel(m_c / mask.elempack).row(h_col)[w_col * mask.elempack + m_c % mask.elempack]; + } + } + const float h_im = h_in + i * dilation_h + offset_h; + const float w_im = w_in + j * dilation_w + offset_w; + + // Bilinear + const bool cond = h_im > -1 && w_im > -1 && h_im < h && w_im < w; + int h_low = 0; + int w_low = 0; + int h_high = 0; + int w_high = 0; + float w1 = 0.f; + float w2 = 0.f; + float w3 = 0.f; + float w4 = 0.f; + bool v1_cond = false; + bool v2_cond = false; + bool v3_cond = false; + bool v4_cond = false; + if (cond) + { + h_low = (int)floorf(h_im); + w_low = (int)floorf(w_im); + h_high = h_low + 1; + w_high = w_low + 1; + + float lh = h_im - h_low; + float lw = w_im - w_low; + float hh = 1 - lh; + float hw = 1 - lw; + + v1_cond = (h_low >= 0 && w_low >= 0); + v2_cond = (h_low >= 0 && w_high <= w - 1); + v3_cond = (h_high <= h - 1 && w_low >= 0); + v4_cond = (h_high <= h - 1 && w_high <= w - 1); + + w1 = hh * hw; + w2 = hh * lw; + w3 = lh * hw; + w4 = lh * lw; + } + + for (int ic = 0; ic < channels; ic++) + { + float val = 0.f; + if (cond) + { + float v1 = v1_cond ? bottom_blob.channel(ic).row(h_low)[w_low] : 0.f; + float v2 = v2_cond ? bottom_blob.channel(ic).row(h_low)[w_high] : 0.f; + float v3 = v3_cond ? bottom_blob.channel(ic).row(h_high)[w_low] : 0.f; + float v4 = v4_cond ? bottom_blob.channel(ic).row(h_high)[w_high] : 0.f; + val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; + } + sum += val * mask_ * weight_ptr[((oc * channels + ic) * kernel_h + i) * kernel_w + j]; + } + } + } + top_blob.channel(oc).row(h_col)[w_col] = activation_ss(sum, activation_type, activation_params); + } } } - - return DeformableConv2D::forward(bottom_blobs_unpacked, top_blobs, opt); } } From 5e1a76315208e9743083153e066d8b937cac36c3 Mon Sep 17 00:00:00 2001 From: nihui <171016+nihui@users.noreply.github.com> Date: Fri, 20 Feb 2026 13:34:37 +0000 Subject: [PATCH 10/14] apply code-format changes --- src/layer/riscv/deformableconv2d_riscv.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/layer/riscv/deformableconv2d_riscv.cpp b/src/layer/riscv/deformableconv2d_riscv.cpp index a0695abe86b3..1bf560afc428 100644 --- a/src/layer/riscv/deformableconv2d_riscv.cpp +++ b/src/layer/riscv/deformableconv2d_riscv.cpp @@ -442,7 +442,7 @@ int DeformableConv2D_riscv::forward(const std::vector& bottom_blobs, std::v const bool offset_not_pack = offset.elempack == 1; const bool mask_not_pack = has_mask ? bottom_blobs[2].elempack == 1 : true; const float* weight_ptr = weight_data_tm; - + // naive deformable conv #pragma omp parallel for num_threads(opt.num_threads) for (int h_col = 0; h_col < outh; h_col++) @@ -490,7 +490,7 @@ int DeformableConv2D_riscv::forward(const std::vector& bottom_blobs, std::v } const float h_im = h_in + i * dilation_h + offset_h; const float w_im = w_in + j * dilation_w + offset_w; - + // Bilinear const bool cond = h_im > -1 && w_im > -1 && h_im < h && w_im < w; int h_low = 0; @@ -511,23 +511,23 @@ int DeformableConv2D_riscv::forward(const std::vector& bottom_blobs, std::v w_low = (int)floorf(w_im); h_high = h_low + 1; w_high = w_low + 1; - + float lh = h_im - h_low; float lw = w_im - w_low; float hh = 1 - lh; float hw = 1 - lw; - + v1_cond = (h_low >= 0 && w_low >= 0); v2_cond = (h_low >= 0 && w_high <= w - 1); v3_cond = (h_high <= h - 1 && w_low >= 0); v4_cond = (h_high <= h - 1 && w_high <= w - 1); - + w1 = hh * hw; w2 = hh * lw; w3 = lh * hw; w4 = lh * lw; } - + for (int ic = 0; ic < channels; ic++) { float val = 0.f; From c76a349eaece934f5d3adda40e3520d8faece9bb Mon Sep 17 00:00:00 2001 From: chenglimin Date: Fri, 22 May 2026 10:43:36 +0800 Subject: [PATCH 11/14] add riscv rvv support for lstm operator --- src/layer/riscv/lstm_riscv.cpp | 653 +++++++++++++++++++++++++++++++++ src/layer/riscv/lstm_riscv.h | 35 ++ 2 files changed, 688 insertions(+) create mode 100644 src/layer/riscv/lstm_riscv.cpp create mode 100644 src/layer/riscv/lstm_riscv.h diff --git a/src/layer/riscv/lstm_riscv.cpp b/src/layer/riscv/lstm_riscv.cpp new file mode 100644 index 000000000000..9eca6cd8ae1f --- /dev/null +++ b/src/layer/riscv/lstm_riscv.cpp @@ -0,0 +1,653 @@ + + + + + + + + + +#include "lstm_riscv.h" +#include +#include "riscv_usability.h" +#include "rvv_mathfun.h" +#include +#include + +namespace ncnn { + +LSTM_riscv::LSTM_riscv() +{ +} + +static inline float dot_product(const float* a, const float* b, int n) +{ + size_t max_vl = __riscv_vsetvlmax_e32m8(); + vfloat32m8_t sum_v = __riscv_vfmv_v_f_f32m8(0.f, max_vl); + + while (n > 0) + { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t va = __riscv_vle32_v_f32m8(a, vl); + vfloat32m8_t vb = __riscv_vle32_v_f32m8(b, vl); + sum_v = __riscv_vfmacc_vv_f32m8_tu(sum_v, va, vb, vl); + a += vl; + b += vl; + n -= vl; + } + + vfloat32m1_t sum_s = __riscv_vfredusum_vs_f32m8_f32m1(sum_v, __riscv_vfmv_v_f_f32m1(0.f, 1), max_vl); + return __riscv_vfmv_f_s_f32m1_f32(sum_s); +} + +static void sigmoid_vector(float* ptr, int n) +{ + while (n > 0) + { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr, vl); + + vfloat32m8_t _neg_p = __riscv_vfmul_vf_f32m8(_p, -1.f, vl); + vfloat32m8_t _exp_neg_p = exp_ps(_neg_p, vl); + vfloat32m8_t _den = __riscv_vfadd_vf_f32m8(_exp_neg_p, 1.f, vl); + _p = __riscv_vfrdiv_vf_f32m8(_den, 1.f, vl); + + __riscv_vse32_v_f32m8(ptr, _p, vl); + ptr += vl; + n -= vl; + } +} + +static void tanh_vector(float* ptr, int n) +{ + while (n > 0) + { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr, vl); + + vfloat32m8_t _2x = __riscv_vfmul_vf_f32m8(_p, 2.f, vl); + vfloat32m8_t _exp2x = exp_ps(_2x, vl); + vfloat32m8_t _num = __riscv_vfsub_vf_f32m8(_exp2x, 1.f, vl); + vfloat32m8_t _den = __riscv_vfadd_vf_f32m8(_exp2x, 1.f, vl); + _p = __riscv_vfdiv_vv_f32m8(_num, _den, vl); + + __riscv_vse32_v_f32m8(ptr, _p, vl); + ptr += vl; + n -= vl; + } +} + +static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + int hidden_size = cell_state.w; + + // hidden_size x 4 + Mat gates(hidden_size, 4, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + Mat tmp_hidden_state; + if (num_output != hidden_size) + { + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + if (tmp_hidden_state.empty()) + return -100; + } + + for (int t = 0; t < T; t++) + { + int ti = reverse ? T - 1 - t : t; + + const float* x = bottom_blob.row(ti); + + float* I_ptr = gates.row(0); + float* F_ptr = gates.row(1); + float* O_ptr = gates.row(2); + float* G_ptr = gates.row(3); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < hidden_size; q++) + { + const float* bias_c_I = bias_c.row(0); + const float* bias_c_F = bias_c.row(1); + const float* bias_c_O = bias_c.row(2); + const float* bias_c_G = bias_c.row(3); + + const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); + const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); + const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); + const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); + + const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); + const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); + const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); + const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); + + float I = bias_c_I[q]; + float F = bias_c_F[q]; + float O = bias_c_O[q]; + float G = bias_c_G[q]; + + I += dot_product(weight_xc_I, x, size); + F += dot_product(weight_xc_F, x, size); + O += dot_product(weight_xc_O, x, size); + G += dot_product(weight_xc_G, x, size); + + I += dot_product(weight_hc_I, hidden_state, num_output); + F += dot_product(weight_hc_F, hidden_state, num_output); + O += dot_product(weight_hc_O, hidden_state, num_output); + G += dot_product(weight_hc_G, hidden_state, num_output); + + I_ptr[q] = I; + F_ptr[q] = F; + O_ptr[q] = O; + G_ptr[q] = G; + } + + sigmoid_vector(I_ptr, hidden_size); + sigmoid_vector(F_ptr, hidden_size); + sigmoid_vector(O_ptr, hidden_size); + tanh_vector(G_ptr, hidden_size); + + // Update cell and hidden + float* cell_ptr = cell_state; + float* hidden_ptr = hidden_state; + float* tmp_hidden_ptr = tmp_hidden_state; + float* output_data = top_blob.row(ti); + + int n = hidden_size; + float* i_p = I_ptr; + float* f_p = F_ptr; + float* o_p = O_ptr; + float* g_p = G_ptr; + float* c_p = cell_ptr; + float* h_out_p = (num_output == hidden_size) ? hidden_ptr : tmp_hidden_ptr; + + while (n > 0) + { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t _i = __riscv_vle32_v_f32m8(i_p, vl); + vfloat32m8_t _f = __riscv_vle32_v_f32m8(f_p, vl); + vfloat32m8_t _o = __riscv_vle32_v_f32m8(o_p, vl); + vfloat32m8_t _g = __riscv_vle32_v_f32m8(g_p, vl); + vfloat32m8_t _c = __riscv_vle32_v_f32m8(c_p, vl); + + // cell = F * cell + I * G + vfloat32m8_t _fc = __riscv_vfmul_vv_f32m8(_f, _c, vl); + vfloat32m8_t _ig = __riscv_vfmul_vv_f32m8(_i, _g, vl); + _c = __riscv_vfadd_vv_f32m8(_fc, _ig, vl); + __riscv_vse32_v_f32m8(c_p, _c, vl); + + // H = O * tanh(cell) + vfloat32m8_t _2c = __riscv_vfmul_vf_f32m8(_c, 2.f, vl); + vfloat32m8_t _exp2c = exp_ps(_2c, vl); + vfloat32m8_t _num = __riscv_vfsub_vf_f32m8(_exp2c, 1.f, vl); + vfloat32m8_t _den = __riscv_vfadd_vf_f32m8(_exp2c, 1.f, vl); + vfloat32m8_t _tanh_c = __riscv_vfdiv_vv_f32m8(_num, _den, vl); + + vfloat32m8_t _h = __riscv_vfmul_vv_f32m8(_o, _tanh_c, vl); + __riscv_vse32_v_f32m8(h_out_p, _h, vl); + + if (num_output == hidden_size) + { + __riscv_vse32_v_f32m8(output_data, _h, vl); + output_data += vl; + } + + i_p += vl; f_p += vl; o_p += vl; g_p += vl; + c_p += vl; h_out_p += vl; + n -= vl; + } + + if (num_output != hidden_size) + { + // Projection + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + const float* tmp_h = tmp_hidden_state; + + float H = dot_product(hr, tmp_h, hidden_size); + + hidden_state[q] = H; + top_blob.row(ti)[q] = H; + } + } + } + + return 0; +} + +#if NCNN_INT8 +static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const float* weight_xc_int8_scales, const Mat& bias_c, const Mat& weight_hc_int8, const float* weight_hc_int8_scales, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + int hidden_size = cell_state.w; + + // 4 x hidden_size + Mat gates(4, hidden_size, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + Mat tmp_hidden_state; + if (num_output != hidden_size) + { + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + if (tmp_hidden_state.empty()) + return -100; + } + + // dynamic quantize bottom_blob + Mat bottom_blob_int8(size, T, (size_t)1u, 1, opt.workspace_allocator); + Mat bottom_blob_int8_scales(T, (size_t)4u, 1, opt.workspace_allocator); + { + for (int t = 0; t < T; t++) + { + const float* x = bottom_blob.row(t); + + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(x[i])); + } + + bottom_blob_int8_scales[t] = 127.f / absmax; + } + + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_quant); + } + + Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator); + Mat hidden_state_int8_scales(1, (size_t)4u, 1, opt.workspace_allocator); + + // unroll + for (int t = 0; t < T; t++) + { + int ti = reverse ? T - 1 - t : t; + + // dynamic quantize hidden_state + { + float absmax = 0.f; + for (int i = 0; i < num_output; i++) + { + absmax = std::max(absmax, (float)fabs(hidden_state[i])); + } + + if (absmax == 0.f) + { + hidden_state_int8_scales[0] = 1.f; + hidden_state_int8.fill(0); + } + else + { + hidden_state_int8_scales[0] = 127.f / absmax; + + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + quantize_to_int8(hidden_state, hidden_state_int8, hidden_state_int8_scales, opt_quant); + } + } + + const signed char* x = bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = 1.f / bottom_blob_int8_scales[ti]; + const float descale_h = 1.f / hidden_state_int8_scales[0]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < hidden_size; q++) + { + const float* bias_c_I = bias_c.row(0); + const float* bias_c_F = bias_c.row(1); + const float* bias_c_O = bias_c.row(2); + const float* bias_c_G = bias_c.row(3); + + float* gates_data = gates.row(q); + + // gate I F O G + const signed char* weight_xc_int8_I = weight_xc_int8.row(hidden_size * 0 + q); + const signed char* weight_xc_int8_F = weight_xc_int8.row(hidden_size * 1 + q); + const signed char* weight_xc_int8_O = weight_xc_int8.row(hidden_size * 2 + q); + const signed char* weight_xc_int8_G = weight_xc_int8.row(hidden_size * 3 + q); + + const signed char* weight_hc_int8_I = weight_hc_int8.row(hidden_size * 0 + q); + const signed char* weight_hc_int8_F = weight_hc_int8.row(hidden_size * 1 + q); + const signed char* weight_hc_int8_O = weight_hc_int8.row(hidden_size * 2 + q); + const signed char* weight_hc_int8_G = weight_hc_int8.row(hidden_size * 3 + q); + + const float descale_xc_I = 1.f / weight_xc_int8_scales[hidden_size * 0 + q]; + const float descale_xc_F = 1.f / weight_xc_int8_scales[hidden_size * 1 + q]; + const float descale_xc_O = 1.f / weight_xc_int8_scales[hidden_size * 2 + q]; + const float descale_xc_G = 1.f / weight_xc_int8_scales[hidden_size * 3 + q]; + const float descale_hc_I = 1.f / weight_hc_int8_scales[hidden_size * 0 + q]; + const float descale_hc_F = 1.f / weight_hc_int8_scales[hidden_size * 1 + q]; + const float descale_hc_O = 1.f / weight_hc_int8_scales[hidden_size * 2 + q]; + const float descale_hc_G = 1.f / weight_hc_int8_scales[hidden_size * 3 + q]; + + int Ix = 0; + int Fx = 0; + int Ox = 0; + int Gx = 0; + for (int i = 0; i < size; i++) + { + signed char xi = x[i]; + + Ix += weight_xc_int8_I[i] * xi; + Fx += weight_xc_int8_F[i] * xi; + Ox += weight_xc_int8_O[i] * xi; + Gx += weight_xc_int8_G[i] * xi; + } + + int Ih = 0; + int Fh = 0; + int Oh = 0; + int Gh = 0; + for (int i = 0; i < num_output; i++) + { + signed char h_cont = hs[i]; + + Ih += weight_hc_int8_I[i] * h_cont; + Fh += weight_hc_int8_F[i] * h_cont; + Oh += weight_hc_int8_O[i] * h_cont; + Gh += weight_hc_int8_G[i] * h_cont; + } + + float I = bias_c_I[q] + Ix * (descale_x * descale_xc_I) + Ih * (descale_h * descale_hc_I); + float F = bias_c_F[q] + Fx * (descale_x * descale_xc_F) + Fh * (descale_h * descale_hc_F); + float O = bias_c_O[q] + Ox * (descale_x * descale_xc_O) + Oh * (descale_h * descale_hc_O); + float G = bias_c_G[q] + Gx * (descale_x * descale_xc_G) + Gh * (descale_h * descale_hc_G); + + gates_data[0] = I; + gates_data[1] = F; + gates_data[2] = O; + gates_data[3] = G; + } + + // lstm unit + float* output_data = top_blob.row(ti); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < hidden_size; q++) + { + const float* gates_data = gates.row(q); + + float I = gates_data[0]; + float F = gates_data[1]; + float O = gates_data[2]; + float G = gates_data[3]; + + I = 1.f / (1.f + expf(-I)); + F = 1.f / (1.f + expf(-F)); + O = 1.f / (1.f + expf(-O)); + G = tanhf(G); + + float cell2 = F * cell_state[q] + I * G; + float H = O * tanhf(cell2); + cell_state[q] = cell2; + + if (num_output == hidden_size) + { + hidden_state[q] = H; + output_data[q] = H; + } + else + { + tmp_hidden_state[q] = H; + } + } + + if (num_output != hidden_size) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_state[i] * hr[i]; + } + + hidden_state[q] = H; + output_data[q] = H; + } + } + } + + return 0; +} +#endif // NCNN_INT8 + +int LSTM_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + int T = bottom_blob.h; + + int num_directions = direction == 2 ? 2 : 1; + + // initial hidden state + Mat hidden(num_output, 4u, opt.workspace_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + + Mat cell(hidden_size, 4u, opt.workspace_allocator); + if (cell.empty()) + return -100; + cell.fill(0.f); + + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + // Uni directional + if (direction == 0 || direction == 1) + { +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + } + + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + + hidden.fill(0.0f); + cell.fill(0.0f); + +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } + + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } + + return 0; +} + +int LSTM_riscv::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& bottom_blob = bottom_blobs[0]; + int T = bottom_blob.h; + int num_directions = direction == 2 ? 2 : 1; + + Mat hidden; + Mat cell; + Allocator* hidden_cell_allocator = top_blobs.size() == 3 ? opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 3) + { + hidden = bottom_blobs[1].clone(hidden_cell_allocator); + cell = bottom_blobs[2].clone(hidden_cell_allocator); + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_cell_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + + cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); + if (cell.empty()) + return -100; + cell.fill(0.f); + } + + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + // Uni directional + if (direction == 0 || direction == 1) + { +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + } + + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + Mat cell0 = cell.row_range(0, 1); +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); + if (ret != 0) + return ret; + } + + Mat hidden1 = hidden.row_range(1, 1); + Mat cell1 = cell.row_range(1, 1); +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); + if (ret != 0) + return ret; + } + + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } + + if (top_blobs.size() == 3) + { + top_blobs[1] = hidden; + top_blobs[2] = cell; + } + + return 0; +} + +} // namespace ncnn + + + + + + + + diff --git a/src/layer/riscv/lstm_riscv.h b/src/layer/riscv/lstm_riscv.h new file mode 100644 index 000000000000..28a6cbca7d84 --- /dev/null +++ b/src/layer/riscv/lstm_riscv.h @@ -0,0 +1,35 @@ + + + + + + + + +#ifndef LAYER_LSTM_RISCV_H +#define LAYER_LSTM_RISCV_H + +#include "lstm.h" + +namespace ncnn { + +class LSTM_riscv : public LSTM +{ +public: + LSTM_riscv(); + + virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +}; + +} // namespace ncnn + +#endif // LAYER_LSTM_RISCV_H + + + + + + + From f72bc2071f3bd24d71d76003e7a8d423efafa381 Mon Sep 17 00:00:00 2001 From: chenglimin <18213449+chenglimin@users.noreply.github.com> Date: Fri, 22 May 2026 02:45:30 +0000 Subject: [PATCH 12/14] apply code-format changes --- src/layer/riscv/lstm_riscv.cpp | 61 ++++++++++++++-------------------- src/layer/riscv/lstm_riscv.h | 13 -------- 2 files changed, 25 insertions(+), 49 deletions(-) diff --git a/src/layer/riscv/lstm_riscv.cpp b/src/layer/riscv/lstm_riscv.cpp index 9eca6cd8ae1f..5fbbf9158551 100644 --- a/src/layer/riscv/lstm_riscv.cpp +++ b/src/layer/riscv/lstm_riscv.cpp @@ -1,12 +1,5 @@ - - - - - - - #include "lstm_riscv.h" #include #include "riscv_usability.h" @@ -24,7 +17,7 @@ static inline float dot_product(const float* a, const float* b, int n) { size_t max_vl = __riscv_vsetvlmax_e32m8(); vfloat32m8_t sum_v = __riscv_vfmv_v_f_f32m8(0.f, max_vl); - + while (n > 0) { size_t vl = __riscv_vsetvl_e32m8(n); @@ -35,7 +28,7 @@ static inline float dot_product(const float* a, const float* b, int n) b += vl; n -= vl; } - + vfloat32m1_t sum_s = __riscv_vfredusum_vs_f32m8_f32m1(sum_v, __riscv_vfmv_v_f_f32m1(0.f, 1), max_vl); return __riscv_vfmv_f_s_f32m1_f32(sum_s); } @@ -46,12 +39,12 @@ static void sigmoid_vector(float* ptr, int n) { size_t vl = __riscv_vsetvl_e32m8(n); vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr, vl); - + vfloat32m8_t _neg_p = __riscv_vfmul_vf_f32m8(_p, -1.f, vl); vfloat32m8_t _exp_neg_p = exp_ps(_neg_p, vl); vfloat32m8_t _den = __riscv_vfadd_vf_f32m8(_exp_neg_p, 1.f, vl); _p = __riscv_vfrdiv_vf_f32m8(_den, 1.f, vl); - + __riscv_vse32_v_f32m8(ptr, _p, vl); ptr += vl; n -= vl; @@ -64,13 +57,13 @@ static void tanh_vector(float* ptr, int n) { size_t vl = __riscv_vsetvl_e32m8(n); vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr, vl); - + vfloat32m8_t _2x = __riscv_vfmul_vf_f32m8(_p, 2.f, vl); vfloat32m8_t _exp2x = exp_ps(_2x, vl); vfloat32m8_t _num = __riscv_vfsub_vf_f32m8(_exp2x, 1.f, vl); vfloat32m8_t _den = __riscv_vfadd_vf_f32m8(_exp2x, 1.f, vl); _p = __riscv_vfdiv_vv_f32m8(_num, _den, vl); - + __riscv_vse32_v_f32m8(ptr, _p, vl); ptr += vl; n -= vl; @@ -103,7 +96,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w int ti = reverse ? T - 1 - t : t; const float* x = bottom_blob.row(ti); - + float* I_ptr = gates.row(0); float* F_ptr = gates.row(1); float* O_ptr = gates.row(2); @@ -147,18 +140,18 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w O_ptr[q] = O; G_ptr[q] = G; } - + sigmoid_vector(I_ptr, hidden_size); sigmoid_vector(F_ptr, hidden_size); sigmoid_vector(O_ptr, hidden_size); tanh_vector(G_ptr, hidden_size); - + // Update cell and hidden float* cell_ptr = cell_state; float* hidden_ptr = hidden_state; float* tmp_hidden_ptr = tmp_hidden_state; float* output_data = top_blob.row(ti); - + int n = hidden_size; float* i_p = I_ptr; float* f_p = F_ptr; @@ -166,7 +159,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w float* g_p = G_ptr; float* c_p = cell_ptr; float* h_out_p = (num_output == hidden_size) ? hidden_ptr : tmp_hidden_ptr; - + while (n > 0) { size_t vl = __riscv_vsetvl_e32m8(n); @@ -175,34 +168,38 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w vfloat32m8_t _o = __riscv_vle32_v_f32m8(o_p, vl); vfloat32m8_t _g = __riscv_vle32_v_f32m8(g_p, vl); vfloat32m8_t _c = __riscv_vle32_v_f32m8(c_p, vl); - + // cell = F * cell + I * G vfloat32m8_t _fc = __riscv_vfmul_vv_f32m8(_f, _c, vl); vfloat32m8_t _ig = __riscv_vfmul_vv_f32m8(_i, _g, vl); _c = __riscv_vfadd_vv_f32m8(_fc, _ig, vl); __riscv_vse32_v_f32m8(c_p, _c, vl); - + // H = O * tanh(cell) vfloat32m8_t _2c = __riscv_vfmul_vf_f32m8(_c, 2.f, vl); vfloat32m8_t _exp2c = exp_ps(_2c, vl); vfloat32m8_t _num = __riscv_vfsub_vf_f32m8(_exp2c, 1.f, vl); vfloat32m8_t _den = __riscv_vfadd_vf_f32m8(_exp2c, 1.f, vl); vfloat32m8_t _tanh_c = __riscv_vfdiv_vv_f32m8(_num, _den, vl); - + vfloat32m8_t _h = __riscv_vfmul_vv_f32m8(_o, _tanh_c, vl); __riscv_vse32_v_f32m8(h_out_p, _h, vl); - + if (num_output == hidden_size) { __riscv_vse32_v_f32m8(output_data, _h, vl); output_data += vl; } - - i_p += vl; f_p += vl; o_p += vl; g_p += vl; - c_p += vl; h_out_p += vl; + + i_p += vl; + f_p += vl; + o_p += vl; + g_p += vl; + c_p += vl; + h_out_p += vl; n -= vl; } - + if (num_output != hidden_size) { // Projection @@ -211,9 +208,9 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w { const float* hr = weight_hr.row(q); const float* tmp_h = tmp_hidden_state; - + float H = dot_product(hr, tmp_h, hidden_size); - + hidden_state[q] = H; top_blob.row(ti)[q] = H; } @@ -643,11 +640,3 @@ int LSTM_riscv::forward(const std::vector& bottom_blobs, std::vector& } } // namespace ncnn - - - - - - - - diff --git a/src/layer/riscv/lstm_riscv.h b/src/layer/riscv/lstm_riscv.h index 28a6cbca7d84..915ff82bdeb8 100644 --- a/src/layer/riscv/lstm_riscv.h +++ b/src/layer/riscv/lstm_riscv.h @@ -1,11 +1,5 @@ - - - - - - #ifndef LAYER_LSTM_RISCV_H #define LAYER_LSTM_RISCV_H @@ -26,10 +20,3 @@ class LSTM_riscv : public LSTM } // namespace ncnn #endif // LAYER_LSTM_RISCV_H - - - - - - - From 4486cf926a60c57482fc1daa8c2fb27fb73708d7 Mon Sep 17 00:00:00 2001 From: chenglimin Date: Fri, 22 May 2026 11:20:00 +0800 Subject: [PATCH 13/14] add riscv rvv support for lstm operator --- src/layer/riscv/deformableconv2d_pack1ton.h | 146 ----- src/layer/riscv/deformableconv2d_packn.h | 158 ------ src/layer/riscv/deformableconv2d_packnto1.h | 158 ------ src/layer/riscv/deformableconv2d_riscv.cpp | 556 -------------------- src/layer/riscv/deformableconv2d_riscv.h | 31 -- 5 files changed, 1049 deletions(-) delete mode 100644 src/layer/riscv/deformableconv2d_pack1ton.h delete mode 100644 src/layer/riscv/deformableconv2d_packn.h delete mode 100644 src/layer/riscv/deformableconv2d_packnto1.h delete mode 100644 src/layer/riscv/deformableconv2d_riscv.cpp delete mode 100644 src/layer/riscv/deformableconv2d_riscv.h diff --git a/src/layer/riscv/deformableconv2d_pack1ton.h b/src/layer/riscv/deformableconv2d_pack1ton.h deleted file mode 100644 index e5ba658c9551..000000000000 --- a/src/layer/riscv/deformableconv2d_pack1ton.h +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2026 Tencent -// SPDX-License-Identifier: BSD-3-Clause - -static void deformableconv2d_pack1ton(const std::vector& bottom_blobs, Mat& top_blob, const Mat& weight_data_packed, const Mat& bias_data, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int pad_left, int pad_top, int activation_type, const Mat& activation_params, const Option& opt) -{ - const Mat& bottom_blob = bottom_blobs[0]; - const Mat& offset = bottom_blobs[1]; - const bool has_mask = (bottom_blobs.size() == 3); - const bool offset_not_pack = offset.elempack == 1; - const bool mask_not_pack = has_mask ? bottom_blobs[2].elempack == 1 : true; - - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - const float* bias_data_ptr = bias_data; - const int packn = csrr_vlenb() / 4; - const size_t vl = __riscv_vsetvl_e32m1(packn); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int h_col = 0; h_col < outh; h_col++) - { - for (int w_col = 0; w_col < outw; w_col++) - { - int h_in = h_col * stride_h - pad_top; - int w_in = w_col * stride_w - pad_left; - for (int oc = 0; oc < outch; oc++) - { - const float* kptr = weight_data_packed.channel(oc); - float* outptr = top_blob.channel(oc); - vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); - if (bias_data_ptr) - _sum = __riscv_vle32_v_f32m1(bias_data_ptr + oc * packn, vl); - - for (int i = 0; i < kernel_h; i++) - { - for (int j = 0; j < kernel_w; j++) - { - float offset_h = 0.f; - float offset_w = 0.f; - float mask_ = 1.f; - if (offset_not_pack) - { - offset_h = offset.channel((i * kernel_w + j) * 2).row(h_col)[w_col]; - offset_w = offset.channel((i * kernel_w + j) * 2 + 1).row(h_col)[w_col]; - } - else - { - const int y_c = (i * kernel_w + j) * 2; - const int x_c = (i * kernel_w + j) * 2 + 1; - offset_h = offset.channel(y_c / offset.elempack).row(h_col)[w_col * offset.elempack + y_c % offset.elempack]; - offset_w = offset.channel(x_c / offset.elempack).row(h_col)[w_col * offset.elempack + x_c % offset.elempack]; - } - if (has_mask) - { - const Mat& mask = bottom_blobs[2]; - if (mask_not_pack) - { - mask_ = mask.channel(i * kernel_w + j).row(h_col)[w_col]; - } - else - { - const int m_c = i * kernel_w + j; - mask_ = mask.channel(m_c / mask.elempack).row(h_col)[w_col * mask.elempack + m_c % mask.elempack]; - } - } - const float h_im = h_in + i * dilation_h + offset_h; - const float w_im = w_in + j * dilation_w + offset_w; - - // Bilinear - const bool cond = h_im > -1 && w_im > -1 && h_im < h && w_im < w; - float w1 = 0.f; - float w2 = 0.f; - float w3 = 0.f; - float w4 = 0.f; - bool v1_cond = false; - bool v2_cond = false; - bool v3_cond = false; - bool v4_cond = false; - int v1_pos = 0; - int v2_pos = 0; - int v3_pos = 0; - int v4_pos = 0; - if (cond) - { - int h_low = (int)floorf(h_im); - int w_low = (int)floorf(w_im); - int h_high = h_low + 1; - int w_high = w_low + 1; - - float lh = h_im - h_low; - float lw = w_im - w_low; - float hh = 1 - lh; - float hw = 1 - lw; - - v1_cond = (h_low >= 0 && w_low >= 0); - v2_cond = (h_low >= 0 && w_high <= w - 1); - v3_cond = (h_high <= h - 1 && w_low >= 0); - v4_cond = (h_high <= h - 1 && w_high <= w - 1); - if (v1_cond) - v1_pos = h_low * w + w_low; - if (v2_cond) - v2_pos = h_low * w + w_high; - if (v3_cond) - v3_pos = h_high * w + w_low; - if (v4_cond) - v4_pos = h_high * w + w_high; - - w1 = hh * hw; - w2 = hh * lw; - w3 = lh * hw; - w4 = lh * lw; - } - - for (int ic = 0; ic < inch; ic++) - { - const float* data_im_ptr = bottom_blob.channel(ic); - - if (cond) - { - float v_in = 0.f; - if (v1_cond) v_in += data_im_ptr[v1_pos] * w1; - if (v2_cond) v_in += data_im_ptr[v2_pos] * w2; - if (v3_cond) v_in += data_im_ptr[v3_pos] * w3; - if (v4_cond) v_in += data_im_ptr[v4_pos] * w4; - - if (has_mask) v_in *= mask_; - - vfloat32m1_t _w = __riscv_vle32_v_f32m1(kptr, vl); - _sum = __riscv_vfmacc_vf_f32m1(_sum, v_in, _w, vl); - } - - kptr += packn; - } - } - } - _sum = activation_ps(_sum, activation_type, activation_params, vl); - __riscv_vse32_v_f32m1(outptr + (h_col * outw + w_col) * packn, _sum, vl); - } - } - } -} diff --git a/src/layer/riscv/deformableconv2d_packn.h b/src/layer/riscv/deformableconv2d_packn.h deleted file mode 100644 index 3a03e9b0b552..000000000000 --- a/src/layer/riscv/deformableconv2d_packn.h +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright 2026 Tencent -// SPDX-License-Identifier: BSD-3-Clause - -static void deformableconv2d_packn(const std::vector& bottom_blobs, Mat& top_blob, const Mat& weight_data_packed, const Mat& bias_data, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int pad_left, int pad_top, int activation_type, const Mat& activation_params, const Option& opt) -{ - const Mat& bottom_blob = bottom_blobs[0]; - const Mat& offset = bottom_blobs[1]; - const bool has_mask = (bottom_blobs.size() == 3); - const bool offset_not_pack = offset.elempack == 1; - const bool mask_not_pack = has_mask ? bottom_blobs[2].elempack == 1 : true; - - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - const float* bias_data_ptr = bias_data; - const int packn = csrr_vlenb() / 4; - const size_t vl = __riscv_vsetvl_e32m1(packn); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int h_col = 0; h_col < outh; h_col++) - { - for (int w_col = 0; w_col < outw; w_col++) - { - int h_in = h_col * stride_h - pad_top; - int w_in = w_col * stride_w - pad_left; - for (int oc = 0; oc < outch; oc++) - { - const float* kptr = weight_data_packed.channel(oc); - float* outptr = top_blob.channel(oc); - vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); - if (bias_data_ptr) - _sum = __riscv_vle32_v_f32m1(bias_data_ptr + oc * packn, vl); - - for (int i = 0; i < kernel_h; i++) - { - for (int j = 0; j < kernel_w; j++) - { - float offset_h = 0.f; - float offset_w = 0.f; - float mask_ = 1.f; - if (offset_not_pack) - { - offset_h = offset.channel((i * kernel_w + j) * 2).row(h_col)[w_col]; - offset_w = offset.channel((i * kernel_w + j) * 2 + 1).row(h_col)[w_col]; - } - else - { - const int y_c = (i * kernel_w + j) * 2; - const int x_c = (i * kernel_w + j) * 2 + 1; - offset_h = offset.channel(y_c / offset.elempack).row(h_col)[w_col * offset.elempack + y_c % offset.elempack]; - offset_w = offset.channel(x_c / offset.elempack).row(h_col)[w_col * offset.elempack + x_c % offset.elempack]; - } - if (has_mask) - { - const Mat& mask = bottom_blobs[2]; - if (mask_not_pack) - { - mask_ = mask.channel(i * kernel_w + j).row(h_col)[w_col]; - } - else - { - const int m_c = i * kernel_w + j; - mask_ = mask.channel(m_c / mask.elempack).row(h_col)[w_col * mask.elempack + m_c % mask.elempack]; - } - } - const float h_im = h_in + i * dilation_h + offset_h; - const float w_im = w_in + j * dilation_w + offset_w; - - // Bilinear - const bool cond = h_im > -1 && w_im > -1 && h_im < h && w_im < w; - float w1 = 0.f; - float w2 = 0.f; - float w3 = 0.f; - float w4 = 0.f; - bool v1_cond = false; - bool v2_cond = false; - bool v3_cond = false; - bool v4_cond = false; - int v1_pos = 0; - int v2_pos = 0; - int v3_pos = 0; - int v4_pos = 0; - if (cond) - { - int h_low = (int)floorf(h_im); - int w_low = (int)floorf(w_im); - int h_high = h_low + 1; - int w_high = w_low + 1; - - float lh = h_im - h_low; - float lw = w_im - w_low; - float hh = 1 - lh; - float hw = 1 - lw; - - v1_cond = (h_low >= 0 && w_low >= 0); - v2_cond = (h_low >= 0 && w_high <= w - 1); - v3_cond = (h_high <= h - 1 && w_low >= 0); - v4_cond = (h_high <= h - 1 && w_high <= w - 1); - if (v1_cond) - v1_pos = h_low * w + w_low; - if (v2_cond) - v2_pos = h_low * w + w_high; - if (v3_cond) - v3_pos = h_high * w + w_low; - if (v4_cond) - v4_pos = h_high * w + w_high; - - w1 = hh * hw; - w2 = hh * lw; - w3 = lh * hw; - w4 = lh * lw; - } - - for (int ic = 0; ic < inch; ic++) - { - const float* data_im_ptr = bottom_blob.channel(ic); - - if (cond) - { - vfloat32m1_t _val = __riscv_vfmv_v_f_f32m1(0.f, vl); - - // Packed-weight memory layout for packn: - // For each output-channel pack, kernel position (kh, kw) and input-channel pack, - // the weights are stored as a contiguous block of size packn_in * packn_out - // (with packn_in == packn_out == packn here). Within this block, lane k in - // the input pack uses the vector loaded from kptr + k * packn, which contains - // the weights from that input lane to all packn output channels. After all - // packn input lanes are processed, kptr is advanced by packn * packn. - for (int k = 0; k < packn; k++) - { - float v_in = 0.f; - if (v1_cond) v_in += data_im_ptr[v1_pos * packn + k] * w1; - if (v2_cond) v_in += data_im_ptr[v2_pos * packn + k] * w2; - if (v3_cond) v_in += data_im_ptr[v3_pos * packn + k] * w3; - if (v4_cond) v_in += data_im_ptr[v4_pos * packn + k] * w4; - - if (has_mask) v_in *= mask_; - - vfloat32m1_t _w = __riscv_vle32_v_f32m1(kptr + k * packn, vl); - _sum = __riscv_vfmacc_vf_f32m1(_sum, v_in, _w, vl); - } - } - - kptr += packn * packn; - } - } - } - _sum = activation_ps(_sum, activation_type, activation_params, vl); - __riscv_vse32_v_f32m1(outptr + (h_col * outw + w_col) * packn, _sum, vl); - } - } - } -} diff --git a/src/layer/riscv/deformableconv2d_packnto1.h b/src/layer/riscv/deformableconv2d_packnto1.h deleted file mode 100644 index d84ccd0a77c8..000000000000 --- a/src/layer/riscv/deformableconv2d_packnto1.h +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright 2026 Tencent -// SPDX-License-Identifier: BSD-3-Clause - -static void deformableconv2d_packnto1(const std::vector& bottom_blobs, Mat& top_blob, const Mat& weight_data_packed, const Mat& bias_data, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int pad_left, int pad_top, int activation_type, const Mat& activation_params, const Option& opt) -{ - const Mat& bottom_blob = bottom_blobs[0]; - const Mat& offset = bottom_blobs[1]; - const bool has_mask = (bottom_blobs.size() == 3); - const bool offset_not_pack = offset.elempack == 1; - const bool mask_not_pack = has_mask ? bottom_blobs[2].elempack == 1 : true; - - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - const float* bias_data_ptr = bias_data; - const int packn = csrr_vlenb() / 4; - const size_t vl = __riscv_vsetvl_e32m1(packn); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int h_col = 0; h_col < outh; h_col++) - { - for (int w_col = 0; w_col < outw; w_col++) - { - int h_in = h_col * stride_h - pad_top; - int w_in = w_col * stride_w - pad_left; - for (int oc = 0; oc < outch; oc++) - { - const float* kptr = weight_data_packed.channel(oc); - float* outptr = top_blob.channel(oc); - float sum = 0.f; - if (bias_data_ptr) - sum = bias_data_ptr[oc]; - - vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl); - - for (int i = 0; i < kernel_h; i++) - { - for (int j = 0; j < kernel_w; j++) - { - float offset_h = 0.f; - float offset_w = 0.f; - float mask_ = 1.f; - if (offset_not_pack) - { - offset_h = offset.channel((i * kernel_w + j) * 2).row(h_col)[w_col]; - offset_w = offset.channel((i * kernel_w + j) * 2 + 1).row(h_col)[w_col]; - } - else - { - const int y_c = (i * kernel_w + j) * 2; - const int x_c = (i * kernel_w + j) * 2 + 1; - offset_h = offset.channel(y_c / offset.elempack).row(h_col)[w_col * offset.elempack + y_c % offset.elempack]; - offset_w = offset.channel(x_c / offset.elempack).row(h_col)[w_col * offset.elempack + x_c % offset.elempack]; - } - if (has_mask) - { - const Mat& mask = bottom_blobs[2]; - if (mask_not_pack) - { - mask_ = mask.channel(i * kernel_w + j).row(h_col)[w_col]; - } - else - { - const int m_c = i * kernel_w + j; - mask_ = mask.channel(m_c / mask.elempack).row(h_col)[w_col * mask.elempack + m_c % mask.elempack]; - } - } - const float h_im = h_in + i * dilation_h + offset_h; - const float w_im = w_in + j * dilation_w + offset_w; - - // Bilinear - const bool cond = h_im > -1 && w_im > -1 && h_im < h && w_im < w; - float w1 = 0.f; - float w2 = 0.f; - float w3 = 0.f; - float w4 = 0.f; - bool v1_cond = false; - bool v2_cond = false; - bool v3_cond = false; - bool v4_cond = false; - int v1_pos = 0; - int v2_pos = 0; - int v3_pos = 0; - int v4_pos = 0; - if (cond) - { - int h_low = (int)floorf(h_im); - int w_low = (int)floorf(w_im); - int h_high = h_low + 1; - int w_high = w_low + 1; - - float lh = h_im - h_low; - float lw = w_im - w_low; - float hh = 1 - lh; - float hw = 1 - lw; - - v1_cond = (h_low >= 0 && w_low >= 0); - v2_cond = (h_low >= 0 && w_high <= w - 1); - v3_cond = (h_high <= h - 1 && w_low >= 0); - v4_cond = (h_high <= h - 1 && w_high <= w - 1); - if (v1_cond) - v1_pos = h_low * w + w_low; - if (v2_cond) - v2_pos = h_low * w + w_high; - if (v3_cond) - v3_pos = h_high * w + w_low; - if (v4_cond) - v4_pos = h_high * w + w_high; - - w1 = hh * hw; - w2 = hh * lw; - w3 = lh * hw; - w4 = lh * lw; - } - - for (int ic = 0; ic < inch; ic++) - { - const float* data_im_ptr = bottom_blob.channel(ic); - - if (cond) - { - vfloat32m1_t _v1 = v1_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v1_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); - vfloat32m1_t _v2 = v2_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v2_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); - vfloat32m1_t _v3 = v3_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v3_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); - vfloat32m1_t _v4 = v4_cond ? __riscv_vle32_v_f32m1(data_im_ptr + v4_pos * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); - - vfloat32m1_t _val = __riscv_vfmv_v_f_f32m1(0.f, vl); - _val = __riscv_vfmacc_vf_f32m1(_val, w1, _v1, vl); - _val = __riscv_vfmacc_vf_f32m1(_val, w2, _v2, vl); - _val = __riscv_vfmacc_vf_f32m1(_val, w3, _v3, vl); - _val = __riscv_vfmacc_vf_f32m1(_val, w4, _v4, vl); - - if (has_mask) - _val = __riscv_vfmul_vf_f32m1(_val, mask_, vl); - - vfloat32m1_t _w = __riscv_vle32_v_f32m1(kptr, vl); - _sum = __riscv_vfmacc_vv_f32m1(_sum, _val, _w, vl); - } - - kptr += packn; - } - } - } - - vfloat32m1_t _v_sum = __riscv_vfredusum_vs_f32m1_f32m1(_sum, __riscv_vfmv_v_f_f32m1(0.f, vl), vl); - sum += __riscv_vfmv_f_s_f32m1_f32(_v_sum); - - sum = activation_ss(sum, activation_type, activation_params); - outptr[h_col * outw + w_col] = sum; - } - } - } -} diff --git a/src/layer/riscv/deformableconv2d_riscv.cpp b/src/layer/riscv/deformableconv2d_riscv.cpp deleted file mode 100644 index 1bf560afc428..000000000000 --- a/src/layer/riscv/deformableconv2d_riscv.cpp +++ /dev/null @@ -1,556 +0,0 @@ -// Copyright 2026 Tencent -// SPDX-License-Identifier: BSD-3-Clause - -#include "deformableconv2d_riscv.h" - -#if __riscv_vector -#include -#endif // __riscv_vector - -#include "riscv_activation.h" -#include "riscv_usability.h" - -#include "benchmark.h" -#include "cpu.h" -#include "layer_type.h" - -namespace ncnn { - -#if __riscv_vector -#include "deformableconv2d_packn.h" -#include "deformableconv2d_pack1ton.h" -#include "deformableconv2d_packnto1.h" -#endif // __riscv_vector - -DeformableConv2D_riscv::DeformableConv2D_riscv() -{ -#if __riscv_vector - support_packing = true; -#endif // __riscv_vector - - activation = 0; - gemm = 0; -} - -static int _4Dindex_to_1Dindex(int i0, int i1, int i2, int i3, int l1, int l2, int l3) -{ - return ((i0 * l1 + i1) * l2 + i2) * l3 + i3; -} - -static int _6Dindex_to_1Dindex(int i0, int i1, int i2, int i3, int i4, int i5, int l1, int l2, int l3, int l4, int l5) -{ - return ((((i0 * l1 + i1) * l2 + i2) * l3 + i3) * l4 + i4) * l5 + i5; -} - -#if __riscv_vector -static void deformableconv2d_transform_kernel_packed_riscv(const Mat& weight_data, Mat& weight_data_tm, int num_input, int num_output, int kernel_w, int kernel_h, int elempack, int out_elempack) -{ - const int maxk = kernel_w * kernel_h; - - // src = kw-kh-inch-outch - // dst = pb-pa-inch/pa-kw-kh-outch/pb - { - const float* weight_ptr = weight_data; - - weight_data_tm.create(num_input * maxk * num_output / (elempack * out_elempack), (size_t)4u * elempack * out_elempack, elempack * out_elempack); - float* ptr = weight_data_tm; - for (int oc = 0; oc < num_output; oc++) - { - for (int i = 0; i < kernel_h; i++) - { - for (int j = 0; j < kernel_w; j++) - { - for (int ic = 0; ic < num_input; ic++) - { - ptr[_6Dindex_to_1Dindex(oc / out_elempack, i, j, ic / elempack, ic % elempack, oc % out_elempack, kernel_h, kernel_w, num_input / elempack, elempack, out_elempack)] = weight_ptr[_4Dindex_to_1Dindex(oc, ic, i, j, num_input, kernel_h, kernel_w)]; - } - } - } - } - weight_data_tm = weight_data_tm.reshape(num_input / elempack, maxk, num_output / out_elempack); - } -} -#endif // __riscv_vector - -int DeformableConv2D_riscv::create_pipeline(const Option& opt) -{ - activation = create_activation_layer(activation_type, activation_params, opt); - - int kernel_size = kernel_w * kernel_h; - int num_input = weight_data_size / kernel_size / num_output; - - int elempack = 1; - int out_elempack = 1; - -#if __riscv_vector - if (opt.use_packing_layout) - { - const int packn = csrr_vlenb() / 4; - elempack = num_input % packn == 0 ? packn : 1; - out_elempack = num_output % packn == 0 ? packn : 1; - } -#endif // __riscv_vector - - if (opt.use_sgemm_convolution) - { - const int maxk = kernel_w * kernel_h; - - gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); - - ncnn::ParamDict pd; - pd.set(2, 0); // transA - pd.set(3, 0); // transB - pd.set(4, 1); // constantA - pd.set(5, 0); // constantB - pd.set(6, 1); // constantC - pd.set(7, num_output); // M = outch - pd.set(8, 0); // N = size - pd.set(9, maxk * num_input); // K = maxk*inch - pd.set(10, bias_term ? 1 : -1); // constant_broadcast_type_C = (M) - pd.set(11, 1); // output_N1M - - gemm->load_param(pd); - - // maxk-inch-outch to pa-maxk-inch/pa-outch - Mat tmp; - { - Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); - - tmp.create(maxk * num_input, num_output); - - for (int q = 0; q < num_output; q += 1) - { - float* g00 = tmp.row(q); - - for (int p = 0; p + (elempack - 1) < num_input; p += elempack) - { - for (int k = 0; k < maxk; k++) - { - for (int i = 0; i < elempack; i++) - { - const float* k00 = weight_data_r2.channel(q).row(p + i); - g00[0] = k00[k]; - g00++; - } - } - } - } - } - - if (bias_term) - { - ncnn::Mat weights[2]; - weights[0] = tmp; - weights[1] = bias_data; - - gemm->load_model(ModelBinFromMatArray(weights)); - } - else - { - ncnn::Mat weights[1]; - weights[0] = tmp; - - gemm->load_model(ModelBinFromMatArray(weights)); - } - - gemm->create_pipeline(opt); - } - else if (elempack == 1 && out_elempack == 1) - { - weight_data_tm = weight_data; - } - else - { -#if __riscv_vector - deformableconv2d_transform_kernel_packed_riscv(weight_data, weight_data_tm, num_input, num_output, kernel_w, kernel_h, elempack, out_elempack); -#endif // __riscv_vector - } - - if (opt.lightmode) - { - weight_data.release(); - } - - return 0; -} - -int DeformableConv2D_riscv::destroy_pipeline(const Option& opt) -{ - if (activation) - { - activation->destroy_pipeline(opt); - delete activation; - activation = 0; - } - - if (gemm) - { - gemm->destroy_pipeline(opt); - delete gemm; - gemm = 0; - } - - return 0; -} - -int DeformableConv2D_riscv::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const -{ - const Mat& bottom_blob = bottom_blobs[0]; - const Mat& offset = bottom_blobs[1]; - const bool has_mask = (bottom_blobs.size() == 3); - Mat& top_blob = top_blobs[0]; - - int w = bottom_blob.w; - int h = bottom_blob.h; - int channels = bottom_blob.c; - size_t elemsize = bottom_blob.elemsize; - int elempack = bottom_blob.elempack; - - const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; - const int kernel_extent_h = dilation_h * (kernel_h - 1) + 1; - const int outw = (w + pad_left + pad_right - kernel_extent_w) / stride_w + 1; - const int outh = (h + pad_top + pad_bottom - kernel_extent_h) / stride_h + 1; - - int out_elempack = 1; -#if __riscv_vector - if (opt.use_packing_layout) - { - const int packn = csrr_vlenb() / 4; - out_elempack = num_output % packn == 0 ? packn : 1; - } -#endif // __riscv_vector - size_t out_elemsize = elemsize / elempack * out_elempack; - - top_blob.create(outw, outh, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); - if (top_blob.empty()) - return -100; - - if (opt.use_sgemm_convolution) - { - const int size = outw * outh; - const int maxk = kernel_w * kernel_h; - - Mat offset_unpacked; - convert_packing(offset, offset_unpacked, 1, opt); - - Mat mask_unpacked; - if (has_mask) - { - const Mat& mask = bottom_blobs[2]; - convert_packing(mask, mask_unpacked, 1, opt); - } - - // im2col - Mat bottom_im2col(size, maxk * channels, elemsize, elempack, opt.workspace_allocator); - -#if __riscv_vector - const int packn = csrr_vlenb() / 4; - if (elempack == packn) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < channels; p++) - { - const Mat img = bottom_blob.channel(p); - float* ptr = bottom_im2col.row(p * maxk); - - for (int u = 0; u < kernel_h; u++) - { - for (int v = 0; v < kernel_w; v++) - { - const Mat offset_h_k = offset_unpacked.channel((u * kernel_w + v) * 2); - const Mat offset_w_k = offset_unpacked.channel((u * kernel_w + v) * 2 + 1); - const Mat mask_k = has_mask ? mask_unpacked.channel(u * kernel_w + v) : 0; - - for (int i = 0; i < outh; i++) - { - for (int j = 0; j < outw; j++) - { - float offset_h = offset_h_k.row(i)[j]; - float offset_w = offset_w_k.row(i)[j]; - - int h_in = i * stride_h - pad_top; - int w_in = j * stride_w - pad_left; - - const float h_im = h_in + u * dilation_h + offset_h; - const float w_im = w_in + v * dilation_w + offset_w; - - // Bilinear - size_t vl = __riscv_vsetvl_e32m1(packn); - vfloat32m1_t _val = __riscv_vfmv_v_f_f32m1(0.f, vl); - bool cond = h_im > -1 && w_im > -1 && h_im < h && w_im < w; - if (cond) - { - int h_low = floor(h_im); - int w_low = floor(w_im); - int h_high = h_low + 1; - int w_high = w_low + 1; - - float lh = h_im - h_low; - float lw = w_im - w_low; - float hh = 1 - lh; - float hw = 1 - lw; - - bool v1_cond = (h_low >= 0 && w_low >= 0); - bool v2_cond = (h_low >= 0 && w_high <= w - 1); - bool v3_cond = (h_high <= h - 1 && w_low >= 0); - bool v4_cond = (h_high <= h - 1 && w_high <= w - 1); - - float w1 = hh * hw; - float w2 = hh * lw; - float w3 = lh * hw; - float w4 = lh * lw; - - vfloat32m1_t _v1 = v1_cond ? __riscv_vle32_v_f32m1((const float*)img.row(h_low) + w_low * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); - vfloat32m1_t _v2 = v2_cond ? __riscv_vle32_v_f32m1((const float*)img.row(h_low) + w_high * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); - vfloat32m1_t _v3 = v3_cond ? __riscv_vle32_v_f32m1((const float*)img.row(h_high) + w_low * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); - vfloat32m1_t _v4 = v4_cond ? __riscv_vle32_v_f32m1((const float*)img.row(h_high) + w_high * packn, vl) : __riscv_vfmv_v_f_f32m1(0.f, vl); - - _val = __riscv_vfmacc_vf_f32m1(_val, w1, _v1, vl); - _val = __riscv_vfmacc_vf_f32m1(_val, w2, _v2, vl); - _val = __riscv_vfmacc_vf_f32m1(_val, w3, _v3, vl); - _val = __riscv_vfmacc_vf_f32m1(_val, w4, _v4, vl); - - if (has_mask) - _val = __riscv_vfmul_vf_f32m1(_val, mask_k.row(i)[j], vl); - } - - __riscv_vse32_v_f32m1(ptr, _val, vl); - - ptr += packn; - } - } - } - } - } - } -#endif // __riscv_vector - - if (elempack == 1) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < channels; p++) - { - const Mat img = bottom_blob.channel(p); - float* ptr = bottom_im2col.row(p * maxk); - - for (int u = 0; u < kernel_h; u++) - { - for (int v = 0; v < kernel_w; v++) - { - const Mat offset_h_k = offset_unpacked.channel((u * kernel_w + v) * 2); - const Mat offset_w_k = offset_unpacked.channel((u * kernel_w + v) * 2 + 1); - const Mat mask_k = has_mask ? mask_unpacked.channel(u * kernel_w + v) : 0; - - for (int i = 0; i < outh; i++) - { - for (int j = 0; j < outw; j++) - { - float offset_h = offset_h_k.row(i)[j]; - float offset_w = offset_w_k.row(i)[j]; - - int h_in = i * stride_h - pad_top; - int w_in = j * stride_w - pad_left; - - const float h_im = h_in + u * dilation_h + offset_h; - const float w_im = w_in + v * dilation_w + offset_w; - - // Bilinear - float val = 0.f; - bool cond = h_im > -1 && w_im > -1 && h_im < h && w_im < w; - if (cond) - { - int h_low = (int)floorf(h_im); - int w_low = (int)floorf(w_im); - int h_high = h_low + 1; - int w_high = w_low + 1; - - float lh = h_im - h_low; - float lw = w_im - w_low; - float hh = 1 - lh; - float hw = 1 - lw; - - bool v1_cond = (h_low >= 0 && w_low >= 0); - bool v2_cond = (h_low >= 0 && w_high <= w - 1); - bool v3_cond = (h_high <= h - 1 && w_low >= 0); - bool v4_cond = (h_high <= h - 1 && w_high <= w - 1); - - float w1 = hh * hw; - float w2 = hh * lw; - float w3 = lh * hw; - float w4 = lh * lw; - - float v1 = v1_cond ? img.row(h_low)[w_low] : 0.f; - float v2 = v2_cond ? img.row(h_low)[w_high] : 0.f; - float v3 = v3_cond ? img.row(h_high)[w_low] : 0.f; - float v4 = v4_cond ? img.row(h_high)[w_high] : 0.f; - val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; - - if (has_mask) - val *= mask_k.row(i)[j]; - } - - ptr[0] = val; - - ptr += 1; - } - } - } - } - } - } - - // sgemm - { - top_blob.w = outw * outh; - top_blob.h = 1; - } - Option opt_b = opt; - opt_b.blob_allocator = opt.workspace_allocator; - gemm->forward(bottom_im2col, top_blob, opt_b); - { - top_blob.w = outw; - top_blob.h = outh; - } - - if (activation) - { - activation->forward_inplace(top_blob, opt); - } - } - else - { -#if __riscv_vector - const int packn = csrr_vlenb() / 4; - - if (elempack == packn && out_elempack == packn) - { - deformableconv2d_packn(bottom_blobs, top_blob, weight_data_tm, bias_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, pad_left, pad_top, activation_type, activation_params, opt); - } - - if (elempack == 1 && out_elempack == packn) - { - deformableconv2d_pack1ton(bottom_blobs, top_blob, weight_data_tm, bias_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, pad_left, pad_top, activation_type, activation_params, opt); - } - - if (elempack == packn && out_elempack == 1) - { - deformableconv2d_packnto1(bottom_blobs, top_blob, weight_data_tm, bias_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, pad_left, pad_top, activation_type, activation_params, opt); - } -#endif // __riscv_vector - if (elempack == 1 && out_elempack == 1) - { - const bool offset_not_pack = offset.elempack == 1; - const bool mask_not_pack = has_mask ? bottom_blobs[2].elempack == 1 : true; - const float* weight_ptr = weight_data_tm; - - // naive deformable conv - #pragma omp parallel for num_threads(opt.num_threads) - for (int h_col = 0; h_col < outh; h_col++) - { - for (int w_col = 0; w_col < outw; w_col++) - { - int h_in = h_col * stride_h - pad_top; - int w_in = w_col * stride_w - pad_left; - for (int oc = 0; oc < num_output; oc++) - { - float sum = 0.f; - if (bias_term) - sum = bias_data[oc]; - for (int i = 0; i < kernel_h; i++) - { - for (int j = 0; j < kernel_w; j++) - { - float offset_h = 0.f; - float offset_w = 0.f; - float mask_ = 1.f; - if (offset_not_pack) - { - offset_h = offset.channel((i * kernel_w + j) * 2).row(h_col)[w_col]; - offset_w = offset.channel((i * kernel_w + j) * 2 + 1).row(h_col)[w_col]; - } - else - { - const int y_c = (i * kernel_w + j) * 2; - const int x_c = (i * kernel_w + j) * 2 + 1; - offset_h = offset.channel(y_c / offset.elempack).row(h_col)[w_col * offset.elempack + y_c % offset.elempack]; - offset_w = offset.channel(x_c / offset.elempack).row(h_col)[w_col * offset.elempack + x_c % offset.elempack]; - } - if (has_mask) - { - const Mat& mask = bottom_blobs[2]; - if (mask_not_pack) - { - mask_ = mask.channel(i * kernel_w + j).row(h_col)[w_col]; - } - else - { - const int m_c = i * kernel_w + j; - mask_ = mask.channel(m_c / mask.elempack).row(h_col)[w_col * mask.elempack + m_c % mask.elempack]; - } - } - const float h_im = h_in + i * dilation_h + offset_h; - const float w_im = w_in + j * dilation_w + offset_w; - - // Bilinear - const bool cond = h_im > -1 && w_im > -1 && h_im < h && w_im < w; - int h_low = 0; - int w_low = 0; - int h_high = 0; - int w_high = 0; - float w1 = 0.f; - float w2 = 0.f; - float w3 = 0.f; - float w4 = 0.f; - bool v1_cond = false; - bool v2_cond = false; - bool v3_cond = false; - bool v4_cond = false; - if (cond) - { - h_low = (int)floorf(h_im); - w_low = (int)floorf(w_im); - h_high = h_low + 1; - w_high = w_low + 1; - - float lh = h_im - h_low; - float lw = w_im - w_low; - float hh = 1 - lh; - float hw = 1 - lw; - - v1_cond = (h_low >= 0 && w_low >= 0); - v2_cond = (h_low >= 0 && w_high <= w - 1); - v3_cond = (h_high <= h - 1 && w_low >= 0); - v4_cond = (h_high <= h - 1 && w_high <= w - 1); - - w1 = hh * hw; - w2 = hh * lw; - w3 = lh * hw; - w4 = lh * lw; - } - - for (int ic = 0; ic < channels; ic++) - { - float val = 0.f; - if (cond) - { - float v1 = v1_cond ? bottom_blob.channel(ic).row(h_low)[w_low] : 0.f; - float v2 = v2_cond ? bottom_blob.channel(ic).row(h_low)[w_high] : 0.f; - float v3 = v3_cond ? bottom_blob.channel(ic).row(h_high)[w_low] : 0.f; - float v4 = v4_cond ? bottom_blob.channel(ic).row(h_high)[w_high] : 0.f; - val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; - } - sum += val * mask_ * weight_ptr[((oc * channels + ic) * kernel_h + i) * kernel_w + j]; - } - } - } - top_blob.channel(oc).row(h_col)[w_col] = activation_ss(sum, activation_type, activation_params); - } - } - } - } - } - - return 0; -} - -} // namespace ncnn diff --git a/src/layer/riscv/deformableconv2d_riscv.h b/src/layer/riscv/deformableconv2d_riscv.h deleted file mode 100644 index d538e5097075..000000000000 --- a/src/layer/riscv/deformableconv2d_riscv.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2026 Tencent -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef LAYER_DEFORMABLECONV2D_RISCV_H -#define LAYER_DEFORMABLECONV2D_RISCV_H - -#include "deformableconv2d.h" - -namespace ncnn { - -class DeformableConv2D_riscv : public DeformableConv2D -{ -public: - DeformableConv2D_riscv(); - - virtual int create_pipeline(const Option& opt); - virtual int destroy_pipeline(const Option& opt); - - virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; - -public: - Layer* activation; - - Mat weight_data_tm; - - Layer* gemm; -}; - -} // namespace ncnn - -#endif // LAYER_DEFORMABLECONV2D_RISCV_H From e3be713501b6f30c7c73eeb16bba46c4087893f5 Mon Sep 17 00:00:00 2001 From: chenglimin Date: Fri, 22 May 2026 11:27:27 +0800 Subject: [PATCH 14/14] add spdx licence header --- src/layer/riscv/lstm_riscv.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layer/riscv/lstm_riscv.cpp b/src/layer/riscv/lstm_riscv.cpp index 5fbbf9158551..5096d8d866d1 100644 --- a/src/layer/riscv/lstm_riscv.cpp +++ b/src/layer/riscv/lstm_riscv.cpp @@ -1,5 +1,5 @@ - - +// Copyright 2021 Tencent +// SPDX-License-Identifier: BSD-3-Clause #include "lstm_riscv.h" #include #include "riscv_usability.h"