diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 614c3b8f31f1..4912f5791053 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -101,6 +101,11 @@ ncnn_add_layer(SPP OFF) ncnn_add_layer(TanH) ncnn_add_layer(Threshold) ncnn_add_layer(Tile) +ncnn_add_layer(TopK) +ncnn_add_layer(Gather) +ncnn_add_layer(GatherElements) +ncnn_add_layer(Mod) +ncnn_add_layer(Expand) ncnn_add_layer(RNN) ncnn_add_layer(LSTM) ncnn_add_layer(BinaryOp) diff --git a/src/layer/cast.cpp b/src/layer/cast.cpp index 3dcff38f3cac..e18a7c3a8ae2 100644 --- a/src/layer/cast.cpp +++ b/src/layer/cast.cpp @@ -74,6 +74,16 @@ int Cast::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons // bfloat16 out_elemsize = 2 * elempack; } + else if (type_to == 5) + { + // int64 + out_elemsize = 8 * elempack; + } + else if (type_to == 6) + { + // int32 + out_elemsize = 4 * elempack; + } if (dims == 1) { @@ -173,6 +183,70 @@ int Cast::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons // TODO more cast type + if (type_from == 5 && type_to == 1) + { + // int64 → float32 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const long long* ptr = bottom_blob.channel(q); + float* outptr = top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = (float)ptr[i]; + } + } + } + + if (type_from == 1 && type_to == 5) + { + // float32 → int64 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = bottom_blob.channel(q); + long long* outptr = top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = (long long)ptr[i]; + } + } + } + + if (type_from == 6 && type_to == 1) + { + // int32 → float32 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const int* ptr = bottom_blob.channel(q); + float* outptr = top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = (float)ptr[i]; + 
} + } + } + + if (type_from == 1 && type_to == 6) + { + // float32 → int32 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = bottom_blob.channel(q); + int* outptr = top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = (int)ptr[i]; + } + } + } + return 0; } diff --git a/src/layer/cast.h b/src/layer/cast.h index 036e61efed04..22c8f5da4626 100644 --- a/src/layer/cast.h +++ b/src/layer/cast.h @@ -24,6 +24,8 @@ class Cast : public Layer // 2 = float16 // 3 = int8 // 4 = bfloat16 + // 5 = int64 + // 6 = int32 int type_from; int type_to; }; diff --git a/src/layer/expand.cpp b/src/layer/expand.cpp new file mode 100644 index 000000000000..7553ce957bad --- /dev/null +++ b/src/layer/expand.cpp @@ -0,0 +1,134 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "expand.h" + +#include +#if !NCNN_SIMPLESTL +#include +#endif + +#if __ARM_NEON +#include +#endif + +namespace ncnn { + +Expand::Expand() +{ + one_blob_only = false; + support_inplace = false; +} + +int Expand::load_param(const ParamDict& /*pd*/) +{ + return 0; +} + +int Expand::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + if (bottom_blobs.size() < 2) + return -1; + + const Mat& input_blob = bottom_blobs[0]; + const Mat& shape_blob = bottom_blobs[1]; + + // shape_blob: 1D tensor of int32 or int64 in ncnn ordering (w, h, c) + const size_t shape_elemsize = shape_blob.elemsize / shape_blob.elempack; + const bool shape_is_int64 = (shape_elemsize == 8); + int target_dims = (shape_blob.dims == 1) ? shape_blob.w : (int)shape_blob.total(); + if (target_dims > 3) target_dims = 3; + + // Input shape in ncnn ordering: index 0=w (innermost), 1=h, 2=c (outermost) + const int in_dims = input_blob.dims; + int in_w = input_blob.w; + int in_h = (in_dims >= 2) ? input_blob.h : 1; + int in_c = (in_dims >= 3) ? 
input_blob.c : 1; + + // Read target shape from shape_blob (ncnn ordering) + int tgt_w = 1, tgt_h = 1, tgt_c = 1; + auto read_shape_dim = [&](int idx) -> int { + if (idx < 0 || idx >= target_dims) return 1; + if (shape_is_int64) return (int)((const int64_t*)(const void*)shape_blob)[idx]; + return ((const int*)(const void*)shape_blob)[idx]; + }; + if (target_dims >= 1) tgt_w = read_shape_dim(0); + if (target_dims >= 2) tgt_h = read_shape_dim(1); + if (target_dims >= 3) tgt_c = read_shape_dim(2); + + // Resolve broadcast: -1 means keep input dim; 1 means broadcast + auto resolve_dim = [](int in_dim, int tgt_dim) -> int { + if (tgt_dim <= 0) return in_dim; // -1 or 0: keep + if (in_dim == 1) return tgt_dim; + return in_dim; // tgt==1 or tgt==in_dim: keep in_dim + }; + + const int out_w = resolve_dim(in_w, tgt_w); + const int out_h = resolve_dim(in_h, tgt_h); + const int out_c = resolve_dim(in_c, tgt_c); + const int out_dims = std::max(in_dims, target_dims); + + // Validate: if neither is 1 and they differ, it's invalid + if ((in_w != 1 && tgt_w != 1 && tgt_w > 0 && in_w != tgt_w) || (in_h != 1 && tgt_h != 1 && tgt_h > 0 && in_h != tgt_h) || (in_c != 1 && tgt_c != 1 && tgt_c > 0 && in_c != tgt_c)) + return -1; + + Mat& top_blob = top_blobs[0]; + if (out_dims == 1) + top_blob.create(out_w, input_blob.elemsize, input_blob.elempack, opt.blob_allocator); + else if (out_dims == 2) + top_blob.create(out_w, out_h, input_blob.elemsize, input_blob.elempack, opt.blob_allocator); + else + top_blob.create(out_w, out_h, out_c, input_blob.elemsize, input_blob.elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + const float* inp = input_blob; + float* out = top_blob; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int z = 0; z < out_c; z++) + { + int sz = (in_c > 1) ? z : 0; + const float* src_chan = inp + sz * (int)input_blob.cstep; + float* dst_chan = out + z * (int)top_blob.cstep; + + for (int y = 0; y < out_h; y++) + { + int sy = (in_h > 1) ? 
y : 0; + const float* src_row = src_chan + sy * in_w; + float* dst_row = dst_chan + y * out_w; + + if (in_w == out_w) + { + memcpy(dst_row, src_row, out_w * sizeof(float)); + } + else // in_w == 1: broadcast scalar across row + { + const float val = src_row[0]; +#if __ARM_NEON + float32x4_t vval = vdupq_n_f32(val); + int x = 0; + // Unroll 4x NEON stores (4 vectors × 4 floats = 16 elements per iteration) + for (; x + 16 <= out_w; x += 16) + { + vst1q_f32(dst_row + x, vval); + vst1q_f32(dst_row + x + 4, vval); + vst1q_f32(dst_row + x + 8, vval); + vst1q_f32(dst_row + x + 12, vval); + } + for (; x + 4 <= out_w; x += 4) + vst1q_f32(dst_row + x, vval); + for (; x < out_w; x++) + dst_row[x] = val; +#else + for (int x = 0; x < out_w; x++) + dst_row[x] = val; +#endif + } + } + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/expand.h b/src/layer/expand.h new file mode 100644 index 000000000000..3d8e0f2534a7 --- /dev/null +++ b/src/layer/expand.h @@ -0,0 +1,23 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LAYER_EXPAND_H +#define LAYER_EXPAND_H + +#include "layer.h" + +namespace ncnn { + +class Expand : public Layer +{ +public: + Expand(); + + virtual int load_param(const ParamDict& pd); + + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +}; + +} // namespace ncnn + +#endif // LAYER_EXPAND_H diff --git a/src/layer/gather.cpp b/src/layer/gather.cpp new file mode 100644 index 000000000000..b8b3e7aa926b --- /dev/null +++ b/src/layer/gather.cpp @@ -0,0 +1,502 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "gather.h" + +#include +#include + +namespace ncnn { + +Gather::Gather() +{ + one_blob_only = false; + support_inplace = false; +} + +int Gather::load_param(const ParamDict& pd) +{ + axis = pd.get(0, 0); + + return 0; +} + +int Gather::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + if 
(bottom_blobs.size() < 2) + return -1; + + const Mat& input_blob = bottom_blobs[0]; + const Mat& index_blob = bottom_blobs[1]; + const int dims = input_blob.dims; + + // Only float32 data supported + if (input_blob.elemsize / input_blob.elempack != 4) + return -1; + + // Only dims 1/2/3 supported + if (dims > 3 || index_blob.dims > 3) + return -1; + + int positive_axis = axis < 0 ? axis + dims : axis; + if (positive_axis < 0 || positive_axis >= dims) + return -1; + + // PyTorch-style axis ordering: axis=0 is outermost (c for 3D, h for 2D, w for 1D) + // shape[] maps axis -> dimension size in that PyTorch order + int shape[3] = {1, 1, 1}; + if (dims == 1) + shape[0] = input_blob.w; + else if (dims == 2) + { + shape[0] = input_blob.h; + shape[1] = input_blob.w; + } + else + { + shape[0] = input_blob.c; + shape[1] = input_blob.h; + shape[2] = input_blob.w; + } + + const int axis_dim_size = shape[positive_axis]; + + // Output shape matches index_blob shape exactly (preserve rank) + Mat& top_blob = top_blobs[0]; + if (index_blob.dims == 1) + top_blob.create(index_blob.w, input_blob.elemsize, input_blob.elempack, opt.blob_allocator); + else if (index_blob.dims == 2) + top_blob.create(index_blob.w, index_blob.h, input_blob.elemsize, input_blob.elempack, opt.blob_allocator); + else + top_blob.create(index_blob.w, index_blob.h, index_blob.c, input_blob.elemsize, input_blob.elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + const float* inp = input_blob; + // Indices may be int32 (elemsize=4) or int64 (elemsize=8) + const size_t idx_elemsize = index_blob.elemsize / index_blob.elempack; + float* out = top_blob; + + const int64_t* idx_ptr64 = (const int64_t*)(const void*)index_blob; + const int* idx_ptr32 = (const int*)(const void*)index_blob; + +#define CLAMP_IDX(gi) \ + do \ + { \ + if ((gi) < 0) (gi) += axis_dim_size; \ + if ((gi) < 0) (gi) = 0; \ + if ((gi) >= axis_dim_size) (gi) = axis_dim_size - 1; \ + } while (0) + + // use_i32: branch hoisted 
once per forward() call, not per element + const bool use_i32 = (idx_elemsize == 4); + + if (dims == 1) + { + if (use_i32) + { + int x = 0; + for (; x + 4 <= index_blob.w; x += 4) + { + int gi0 = idx_ptr32[x]; + CLAMP_IDX(gi0); + int gi1 = idx_ptr32[x + 1]; + CLAMP_IDX(gi1); + int gi2 = idx_ptr32[x + 2]; + CLAMP_IDX(gi2); + int gi3 = idx_ptr32[x + 3]; + CLAMP_IDX(gi3); + out[x] = inp[gi0]; + out[x + 1] = inp[gi1]; + out[x + 2] = inp[gi2]; + out[x + 3] = inp[gi3]; + } + for (; x < index_blob.w; x++) + { + int gi = idx_ptr32[x]; + CLAMP_IDX(gi); + out[x] = inp[gi]; + } + } + else + { + int x = 0; + for (; x + 4 <= index_blob.w; x += 4) + { + int gi0 = (int)idx_ptr64[x]; + CLAMP_IDX(gi0); + int gi1 = (int)idx_ptr64[x + 1]; + CLAMP_IDX(gi1); + int gi2 = (int)idx_ptr64[x + 2]; + CLAMP_IDX(gi2); + int gi3 = (int)idx_ptr64[x + 3]; + CLAMP_IDX(gi3); + out[x] = inp[gi0]; + out[x + 1] = inp[gi1]; + out[x + 2] = inp[gi2]; + out[x + 3] = inp[gi3]; + } + for (; x < index_blob.w; x++) + { + int gi = (int)idx_ptr64[x]; + CLAMP_IDX(gi); + out[x] = inp[gi]; + } + } + } + else if (dims == 2) + { + const int iw = input_blob.w; + const int idxw = index_blob.w; + + if (positive_axis == 0) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < index_blob.h; y++) + { + float* out_row = out + y * top_blob.w; + if (use_i32) + { + const int* ir = idx_ptr32 + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = ir[x]; + CLAMP_IDX(gi0); + int gi1 = ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = inp[gi0 * iw + x]; + out_row[x + 1] = inp[gi1 * iw + x + 1]; + out_row[x + 2] = inp[gi2 * iw + x + 2]; + out_row[x + 3] = inp[gi3 * iw + x + 3]; + } + for (; x < idxw; x++) + { + int gi = ir[x]; + CLAMP_IDX(gi); + out_row[x] = inp[gi * iw + x]; + } + } + else + { + const int64_t* ir = idx_ptr64 + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = (int)ir[x]; + 
CLAMP_IDX(gi0); + int gi1 = (int)ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = (int)ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = (int)ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = inp[gi0 * iw + x]; + out_row[x + 1] = inp[gi1 * iw + x + 1]; + out_row[x + 2] = inp[gi2 * iw + x + 2]; + out_row[x + 3] = inp[gi3 * iw + x + 3]; + } + for (; x < idxw; x++) + { + int gi = (int)ir[x]; + CLAMP_IDX(gi); + out_row[x] = inp[gi * iw + x]; + } + } + } + } + else // positive_axis == 1 + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < index_blob.h; y++) + { + const float* inp_row = inp + y * iw; + float* out_row = out + y * top_blob.w; + if (use_i32) + { + const int* ir = idx_ptr32 + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = ir[x]; + CLAMP_IDX(gi0); + int gi1 = ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = inp_row[gi0]; + out_row[x + 1] = inp_row[gi1]; + out_row[x + 2] = inp_row[gi2]; + out_row[x + 3] = inp_row[gi3]; + } + for (; x < idxw; x++) + { + int gi = ir[x]; + CLAMP_IDX(gi); + out_row[x] = inp_row[gi]; + } + } + else + { + const int64_t* ir = idx_ptr64 + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = (int)ir[x]; + CLAMP_IDX(gi0); + int gi1 = (int)ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = (int)ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = (int)ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = inp_row[gi0]; + out_row[x + 1] = inp_row[gi1]; + out_row[x + 2] = inp_row[gi2]; + out_row[x + 3] = inp_row[gi3]; + } + for (; x < idxw; x++) + { + int gi = (int)ir[x]; + CLAMP_IDX(gi); + out_row[x] = inp_row[gi]; + } + } + } + } + } + else // dims == 3 + { + const int iw = input_blob.w; + const size_t in_cstep = input_blob.cstep; + const size_t idx_cstep = index_blob.cstep; + const size_t out_cstep = top_blob.cstep; + const int idxw = index_blob.w; + + if (positive_axis == 0) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int 
z = 0; z < index_blob.c; z++) + { + float* out_chan = out + z * out_cstep; + const int idx_z_base = (int)(z * idx_cstep); + if (use_i32) + { + for (int y = 0; y < index_blob.h; y++) + { + float* out_row = out_chan + y * top_blob.w; + const int* ir = idx_ptr32 + idx_z_base + y * idxw; + const int inp_y_off = y * iw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = ir[x]; + CLAMP_IDX(gi0); + int gi1 = ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = inp[(int)(gi0 * in_cstep) + inp_y_off + x]; + out_row[x + 1] = inp[(int)(gi1 * in_cstep) + inp_y_off + x + 1]; + out_row[x + 2] = inp[(int)(gi2 * in_cstep) + inp_y_off + x + 2]; + out_row[x + 3] = inp[(int)(gi3 * in_cstep) + inp_y_off + x + 3]; + } + for (; x < idxw; x++) + { + int gi = ir[x]; + CLAMP_IDX(gi); + out_row[x] = inp[(int)(gi * in_cstep) + inp_y_off + x]; + } + } + } + else + { + for (int y = 0; y < index_blob.h; y++) + { + float* out_row = out_chan + y * top_blob.w; + const int64_t* ir = idx_ptr64 + idx_z_base + y * idxw; + const int inp_y_off = y * iw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = (int)ir[x]; + CLAMP_IDX(gi0); + int gi1 = (int)ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = (int)ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = (int)ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = inp[(int)(gi0 * in_cstep) + inp_y_off + x]; + out_row[x + 1] = inp[(int)(gi1 * in_cstep) + inp_y_off + x + 1]; + out_row[x + 2] = inp[(int)(gi2 * in_cstep) + inp_y_off + x + 2]; + out_row[x + 3] = inp[(int)(gi3 * in_cstep) + inp_y_off + x + 3]; + } + for (; x < idxw; x++) + { + int gi = (int)ir[x]; + CLAMP_IDX(gi); + out_row[x] = inp[(int)(gi * in_cstep) + inp_y_off + x]; + } + } + } + } + } + else if (positive_axis == 1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int z = 0; z < index_blob.c; z++) + { + const float* inp_chan = inp + z * in_cstep; + float* out_chan = out + z * out_cstep; + const int idx_z_base 
= (int)(z * idx_cstep); + if (use_i32) + { + for (int y = 0; y < index_blob.h; y++) + { + float* out_row = out_chan + y * top_blob.w; + const int* ir = idx_ptr32 + idx_z_base + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = ir[x]; + CLAMP_IDX(gi0); + int gi1 = ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = inp_chan[gi0 * iw + x]; + out_row[x + 1] = inp_chan[gi1 * iw + x + 1]; + out_row[x + 2] = inp_chan[gi2 * iw + x + 2]; + out_row[x + 3] = inp_chan[gi3 * iw + x + 3]; + } + for (; x < idxw; x++) + { + int gi = ir[x]; + CLAMP_IDX(gi); + out_row[x] = inp_chan[gi * iw + x]; + } + } + } + else + { + for (int y = 0; y < index_blob.h; y++) + { + float* out_row = out_chan + y * top_blob.w; + const int64_t* ir = idx_ptr64 + idx_z_base + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = (int)ir[x]; + CLAMP_IDX(gi0); + int gi1 = (int)ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = (int)ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = (int)ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = inp_chan[gi0 * iw + x]; + out_row[x + 1] = inp_chan[gi1 * iw + x + 1]; + out_row[x + 2] = inp_chan[gi2 * iw + x + 2]; + out_row[x + 3] = inp_chan[gi3 * iw + x + 3]; + } + for (; x < idxw; x++) + { + int gi = (int)ir[x]; + CLAMP_IDX(gi); + out_row[x] = inp_chan[gi * iw + x]; + } + } + } + } + } + else // positive_axis == 2 + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int z = 0; z < index_blob.c; z++) + { + const float* inp_chan = inp + z * in_cstep; + float* out_chan = out + z * out_cstep; + const int idx_z_base = (int)(z * idx_cstep); + if (use_i32) + { + for (int y = 0; y < index_blob.h; y++) + { + const float* inp_row = inp_chan + y * iw; + float* out_row = out_chan + y * top_blob.w; + const int* ir = idx_ptr32 + idx_z_base + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = ir[x]; + CLAMP_IDX(gi0); + int gi1 = ir[x + 1]; + CLAMP_IDX(gi1); + int 
gi2 = ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = inp_row[gi0]; + out_row[x + 1] = inp_row[gi1]; + out_row[x + 2] = inp_row[gi2]; + out_row[x + 3] = inp_row[gi3]; + } + for (; x < idxw; x++) + { + int gi = ir[x]; + CLAMP_IDX(gi); + out_row[x] = inp_row[gi]; + } + } + } + else + { + for (int y = 0; y < index_blob.h; y++) + { + const float* inp_row = inp_chan + y * iw; + float* out_row = out_chan + y * top_blob.w; + const int64_t* ir = idx_ptr64 + idx_z_base + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = (int)ir[x]; + CLAMP_IDX(gi0); + int gi1 = (int)ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = (int)ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = (int)ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = inp_row[gi0]; + out_row[x + 1] = inp_row[gi1]; + out_row[x + 2] = inp_row[gi2]; + out_row[x + 3] = inp_row[gi3]; + } + for (; x < idxw; x++) + { + int gi = (int)ir[x]; + CLAMP_IDX(gi); + out_row[x] = inp_row[gi]; + } + } + } + } + } + } + +#undef CLAMP_IDX + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/gather.h b/src/layer/gather.h new file mode 100644 index 000000000000..f8d24d9afb54 --- /dev/null +++ b/src/layer/gather.h @@ -0,0 +1,27 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LAYER_GATHER_H +#define LAYER_GATHER_H + +#include "layer.h" + +namespace ncnn { + +class Gather : public Layer +{ +public: + Gather(); + + virtual int load_param(const ParamDict& pd); + + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +public: + // param_0 = axis (default 0) + int axis; +}; + +} // namespace ncnn + +#endif // LAYER_GATHER_H diff --git a/src/layer/gatherelements.cpp b/src/layer/gatherelements.cpp new file mode 100644 index 000000000000..70733c958107 --- /dev/null +++ b/src/layer/gatherelements.cpp @@ -0,0 +1,491 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "gatherelements.h" + 
+#include +#include + +namespace ncnn { + +GatherElements::GatherElements() +{ + one_blob_only = false; + support_inplace = false; +} + +int GatherElements::load_param(const ParamDict& pd) +{ + axis = pd.get(0, 0); + + return 0; +} + +int GatherElements::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + if (bottom_blobs.size() < 2) + return -1; + + const Mat& data_blob = bottom_blobs[0]; + const Mat& index_blob = bottom_blobs[1]; + + // Output has same shape as index_blob (same rank) + Mat& top_blob = top_blobs[0]; + if (index_blob.dims == 1) + top_blob.create(index_blob.w, data_blob.elemsize, data_blob.elempack, opt.blob_allocator); + else if (index_blob.dims == 2) + top_blob.create(index_blob.w, index_blob.h, data_blob.elemsize, data_blob.elempack, opt.blob_allocator); + else + top_blob.create(index_blob.w, index_blob.h, index_blob.c, data_blob.elemsize, data_blob.elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + const int data_dims = data_blob.dims; + const int positive_axis = axis < 0 ? 
axis + data_dims : axis; + if (positive_axis < 0 || positive_axis >= data_dims) + return -1; + + const float* data = data_blob; + const size_t idx_elemsize = index_blob.elemsize / index_blob.elempack; + float* out = top_blob; + + // PyTorch/ONNX axis ordering: axis=0 = outermost (c for 3D, h for 2D, w for 1D) + int data_shape[3] = {1, 1, 1}; + if (data_dims == 1) + data_shape[0] = data_blob.w; + else if (data_dims == 2) + { + data_shape[0] = data_blob.h; + data_shape[1] = data_blob.w; + } + else + { + data_shape[0] = data_blob.c; + data_shape[1] = data_blob.h; + data_shape[2] = data_blob.w; + } + const int axis_dim_size = data_shape[positive_axis]; + + const int64_t* idx_ptr64 = (const int64_t*)(const void*)index_blob; + const int* idx_ptr32 = (const int*)(const void*)index_blob; + +#define CLAMP_IDX(gi) \ + do \ + { \ + if ((gi) < 0) (gi) += axis_dim_size; \ + if ((gi) < 0) (gi) = 0; \ + if ((gi) >= axis_dim_size) (gi) = axis_dim_size - 1; \ + } while (0) + + // use_i32: branch hoisted once per forward() call, not per element + const bool use_i32 = (idx_elemsize == 4); + + if (data_dims == 1) + { + if (use_i32) + { + int x = 0; + for (; x + 4 <= index_blob.w; x += 4) + { + int gi0 = idx_ptr32[x]; + CLAMP_IDX(gi0); + int gi1 = idx_ptr32[x + 1]; + CLAMP_IDX(gi1); + int gi2 = idx_ptr32[x + 2]; + CLAMP_IDX(gi2); + int gi3 = idx_ptr32[x + 3]; + CLAMP_IDX(gi3); + out[x] = data[gi0]; + out[x + 1] = data[gi1]; + out[x + 2] = data[gi2]; + out[x + 3] = data[gi3]; + } + for (; x < index_blob.w; x++) + { + int gi = idx_ptr32[x]; + CLAMP_IDX(gi); + out[x] = data[gi]; + } + } + else + { + int x = 0; + for (; x + 4 <= index_blob.w; x += 4) + { + int gi0 = (int)idx_ptr64[x]; + CLAMP_IDX(gi0); + int gi1 = (int)idx_ptr64[x + 1]; + CLAMP_IDX(gi1); + int gi2 = (int)idx_ptr64[x + 2]; + CLAMP_IDX(gi2); + int gi3 = (int)idx_ptr64[x + 3]; + CLAMP_IDX(gi3); + out[x] = data[gi0]; + out[x + 1] = data[gi1]; + out[x + 2] = data[gi2]; + out[x + 3] = data[gi3]; + } + for (; x < index_blob.w; 
x++) + { + int gi = (int)idx_ptr64[x]; + CLAMP_IDX(gi); + out[x] = data[gi]; + } + } + } + else if (data_dims == 2) + { + const int dw = data_blob.w; + const int idxw = index_blob.w; + + if (positive_axis == 0) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < index_blob.h; y++) + { + float* out_row = out + y * top_blob.w; + if (use_i32) + { + const int* ir = idx_ptr32 + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = ir[x]; + CLAMP_IDX(gi0); + int gi1 = ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = data[gi0 * dw + x]; + out_row[x + 1] = data[gi1 * dw + x + 1]; + out_row[x + 2] = data[gi2 * dw + x + 2]; + out_row[x + 3] = data[gi3 * dw + x + 3]; + } + for (; x < idxw; x++) + { + int gi = ir[x]; + CLAMP_IDX(gi); + out_row[x] = data[gi * dw + x]; + } + } + else + { + const int64_t* ir = idx_ptr64 + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = (int)ir[x]; + CLAMP_IDX(gi0); + int gi1 = (int)ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = (int)ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = (int)ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = data[gi0 * dw + x]; + out_row[x + 1] = data[gi1 * dw + x + 1]; + out_row[x + 2] = data[gi2 * dw + x + 2]; + out_row[x + 3] = data[gi3 * dw + x + 3]; + } + for (; x < idxw; x++) + { + int gi = (int)ir[x]; + CLAMP_IDX(gi); + out_row[x] = data[gi * dw + x]; + } + } + } + } + else // positive_axis == 1 + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < index_blob.h; y++) + { + const float* data_row = data + y * dw; + float* out_row = out + y * top_blob.w; + if (use_i32) + { + const int* ir = idx_ptr32 + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = ir[x]; + CLAMP_IDX(gi0); + int gi1 = ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = data_row[gi0]; + out_row[x + 1] 
= data_row[gi1]; + out_row[x + 2] = data_row[gi2]; + out_row[x + 3] = data_row[gi3]; + } + for (; x < idxw; x++) + { + int gi = ir[x]; + CLAMP_IDX(gi); + out_row[x] = data_row[gi]; + } + } + else + { + const int64_t* ir = idx_ptr64 + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = (int)ir[x]; + CLAMP_IDX(gi0); + int gi1 = (int)ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = (int)ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = (int)ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = data_row[gi0]; + out_row[x + 1] = data_row[gi1]; + out_row[x + 2] = data_row[gi2]; + out_row[x + 3] = data_row[gi3]; + } + for (; x < idxw; x++) + { + int gi = (int)ir[x]; + CLAMP_IDX(gi); + out_row[x] = data_row[gi]; + } + } + } + } + } + else // data_dims == 3 + { + const int dw = data_blob.w; + const size_t in_cstep = data_blob.cstep; + const size_t idx_cstep = index_blob.cstep; + const size_t out_cstep = top_blob.cstep; + const int idxw = index_blob.w; + + if (positive_axis == 0) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int z = 0; z < index_blob.c; z++) + { + float* out_chan = out + z * out_cstep; + const int idx_z_base = (int)(z * idx_cstep); + if (use_i32) + { + for (int y = 0; y < index_blob.h; y++) + { + float* out_row = out_chan + y * top_blob.w; + const int* ir = idx_ptr32 + idx_z_base + y * idxw; + const int inp_y_off = y * dw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = ir[x]; + CLAMP_IDX(gi0); + int gi1 = ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = data[(int)(gi0 * in_cstep) + inp_y_off + x]; + out_row[x + 1] = data[(int)(gi1 * in_cstep) + inp_y_off + x + 1]; + out_row[x + 2] = data[(int)(gi2 * in_cstep) + inp_y_off + x + 2]; + out_row[x + 3] = data[(int)(gi3 * in_cstep) + inp_y_off + x + 3]; + } + for (; x < idxw; x++) + { + int gi = ir[x]; + CLAMP_IDX(gi); + out_row[x] = data[(int)(gi * in_cstep) + inp_y_off + x]; + } + } + } + else + { + for (int y 
= 0; y < index_blob.h; y++) + { + float* out_row = out_chan + y * top_blob.w; + const int64_t* ir = idx_ptr64 + idx_z_base + y * idxw; + const int inp_y_off = y * dw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = (int)ir[x]; + CLAMP_IDX(gi0); + int gi1 = (int)ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = (int)ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = (int)ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = data[(int)(gi0 * in_cstep) + inp_y_off + x]; + out_row[x + 1] = data[(int)(gi1 * in_cstep) + inp_y_off + x + 1]; + out_row[x + 2] = data[(int)(gi2 * in_cstep) + inp_y_off + x + 2]; + out_row[x + 3] = data[(int)(gi3 * in_cstep) + inp_y_off + x + 3]; + } + for (; x < idxw; x++) + { + int gi = (int)ir[x]; + CLAMP_IDX(gi); + out_row[x] = data[(int)(gi * in_cstep) + inp_y_off + x]; + } + } + } + } + } + else if (positive_axis == 1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int z = 0; z < index_blob.c; z++) + { + const float* data_chan = data + z * in_cstep; + float* out_chan = out + z * out_cstep; + const int idx_z_base = (int)(z * idx_cstep); + if (use_i32) + { + for (int y = 0; y < index_blob.h; y++) + { + float* out_row = out_chan + y * top_blob.w; + const int* ir = idx_ptr32 + idx_z_base + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = ir[x]; + CLAMP_IDX(gi0); + int gi1 = ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = data_chan[gi0 * dw + x]; + out_row[x + 1] = data_chan[gi1 * dw + x + 1]; + out_row[x + 2] = data_chan[gi2 * dw + x + 2]; + out_row[x + 3] = data_chan[gi3 * dw + x + 3]; + } + for (; x < idxw; x++) + { + int gi = ir[x]; + CLAMP_IDX(gi); + out_row[x] = data_chan[gi * dw + x]; + } + } + } + else + { + for (int y = 0; y < index_blob.h; y++) + { + float* out_row = out_chan + y * top_blob.w; + const int64_t* ir = idx_ptr64 + idx_z_base + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = (int)ir[x]; + 
CLAMP_IDX(gi0); + int gi1 = (int)ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = (int)ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = (int)ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = data_chan[gi0 * dw + x]; + out_row[x + 1] = data_chan[gi1 * dw + x + 1]; + out_row[x + 2] = data_chan[gi2 * dw + x + 2]; + out_row[x + 3] = data_chan[gi3 * dw + x + 3]; + } + for (; x < idxw; x++) + { + int gi = (int)ir[x]; + CLAMP_IDX(gi); + out_row[x] = data_chan[gi * dw + x]; + } + } + } + } + } + else // positive_axis == 2 + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int z = 0; z < index_blob.c; z++) + { + const float* data_chan = data + z * in_cstep; + float* out_chan = out + z * out_cstep; + const int idx_z_base = (int)(z * idx_cstep); + if (use_i32) + { + for (int y = 0; y < index_blob.h; y++) + { + const float* data_row = data_chan + y * dw; + float* out_row = out_chan + y * top_blob.w; + const int* ir = idx_ptr32 + idx_z_base + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = ir[x]; + CLAMP_IDX(gi0); + int gi1 = ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = data_row[gi0]; + out_row[x + 1] = data_row[gi1]; + out_row[x + 2] = data_row[gi2]; + out_row[x + 3] = data_row[gi3]; + } + for (; x < idxw; x++) + { + int gi = ir[x]; + CLAMP_IDX(gi); + out_row[x] = data_row[gi]; + } + } + } + else + { + for (int y = 0; y < index_blob.h; y++) + { + const float* data_row = data_chan + y * dw; + float* out_row = out_chan + y * top_blob.w; + const int64_t* ir = idx_ptr64 + idx_z_base + y * idxw; + int x = 0; + for (; x + 4 <= idxw; x += 4) + { + int gi0 = (int)ir[x]; + CLAMP_IDX(gi0); + int gi1 = (int)ir[x + 1]; + CLAMP_IDX(gi1); + int gi2 = (int)ir[x + 2]; + CLAMP_IDX(gi2); + int gi3 = (int)ir[x + 3]; + CLAMP_IDX(gi3); + out_row[x] = data_row[gi0]; + out_row[x + 1] = data_row[gi1]; + out_row[x + 2] = data_row[gi2]; + out_row[x + 3] = data_row[gi3]; + } + for (; x < idxw; x++) + { + 
int gi = (int)ir[x]; + CLAMP_IDX(gi); + out_row[x] = data_row[gi]; + } + } + } + } + } + } + +#undef CLAMP_IDX + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/gatherelements.h b/src/layer/gatherelements.h new file mode 100644 index 000000000000..2399c1581b20 --- /dev/null +++ b/src/layer/gatherelements.h @@ -0,0 +1,27 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LAYER_GATHERELEMENTS_H +#define LAYER_GATHERELEMENTS_H + +#include "layer.h" + +namespace ncnn { + +class GatherElements : public Layer +{ +public: + GatherElements(); + + virtual int load_param(const ParamDict& pd); + + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +public: + // param_0 = axis (default 0) + int axis; +}; + +} // namespace ncnn + +#endif // LAYER_GATHERELEMENTS_H diff --git a/src/layer/mod.cpp b/src/layer/mod.cpp new file mode 100644 index 000000000000..df48f6fdb382 --- /dev/null +++ b/src/layer/mod.cpp @@ -0,0 +1,92 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "mod.h" + +#include + +namespace ncnn { + +Mod::Mod() +{ + one_blob_only = false; + support_inplace = false; + fmod = 0; +} + +int Mod::load_param(const ParamDict& pd) +{ + fmod = pd.get(0, 0); + + return 0; +} + +int Mod::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + if (bottom_blobs.size() < 2) + return -1; + + const Mat& a_blob = bottom_blobs[0]; + const Mat& b_blob = bottom_blobs[1]; + + // Output has same shape as a_blob + const Mat& out_shape = a_blob; + + Mat& top_blob = top_blobs[0]; + top_blob.create(out_shape.w, out_shape.h, out_shape.c, a_blob.elemsize, a_blob.elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + const int out_w = top_blob.w; + const int out_h = top_blob.h; + const int out_c = top_blob.c; + + const int count = out_h * out_w; // contiguous elements per channel slice + + if (fmod == 0) + { + 
// Python-style modulo (remainder with same sign as divisor) + #pragma omp parallel for num_threads(opt.num_threads) + for (int z = 0; z < out_c; z++) + { + const float* aptr = (const float*)a_blob + z * (int)a_blob.cstep; + const float* bptr = (const float*)b_blob + z * (int)b_blob.cstep; + float* optr = (float*)top_blob + z * (int)top_blob.cstep; + for (int i = 0; i < count; i++) + { + const float val_b = bptr[i]; + if (val_b == 0.0f) + { + optr[i] = 0.0f; + } + else + { + float result = ::fmodf(aptr[i], val_b); + if ((result != 0.0f) && ((val_b < 0.0f) != (result < 0.0f))) + result += val_b; + optr[i] = result; + } + } + } + } + else + { + // C-style fmod (remainder with same sign as dividend) + #pragma omp parallel for num_threads(opt.num_threads) + for (int z = 0; z < out_c; z++) + { + const float* aptr = (const float*)a_blob + z * (int)a_blob.cstep; + const float* bptr = (const float*)b_blob + z * (int)b_blob.cstep; + float* optr = (float*)top_blob + z * (int)top_blob.cstep; + for (int i = 0; i < count; i++) + { + const float val_b = bptr[i]; + optr[i] = (val_b == 0.0f) ? 
0.0f : ::fmodf(aptr[i], val_b); + } + } + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/mod.h b/src/layer/mod.h new file mode 100644 index 000000000000..9f7e23a39c76 --- /dev/null +++ b/src/layer/mod.h @@ -0,0 +1,26 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LAYER_MOD_H +#define LAYER_MOD_H + +#include "layer.h" + +namespace ncnn { + +class Mod : public Layer +{ +public: + Mod(); + + virtual int load_param(const ParamDict& pd); + + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +public: + int fmod; // 0 = remainder (Python-style), 1 = fmod (C-style) +}; + +} // namespace ncnn + +#endif // LAYER_MOD_H diff --git a/src/layer/tile.cpp b/src/layer/tile.cpp index f9d253e434f4..e3005483a58b 100644 --- a/src/layer/tile.cpp +++ b/src/layer/tile.cpp @@ -3,11 +3,13 @@ #include "tile.h" +#include + namespace ncnn { Tile::Tile() { - one_blob_only = true; + one_blob_only = false; support_inplace = false; } @@ -20,6 +22,34 @@ int Tile::load_param(const ParamDict& pd) return 0; } +int Tile::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + // ONNX mode: repeats comes as the second input blob. + // Extract repeats into a local Mat and delegate to the single-blob path. + if (bottom_blobs.size() >= 2 && !bottom_blobs[1].empty()) + { + const Mat& repeats_blob = bottom_blobs[1]; + const int* rptr = (const int*)(const void*)repeats_blob; + int rcount = (repeats_blob.dims == 1) ? 
repeats_blob.w : (int)repeats_blob.total(); + + // Build a param-style Mat for the repeats (int32, 1D, length rcount) + Mat repeats_param(rcount, (size_t)4u); + int* dst = (int*)(void*)repeats_param; + for (int i = 0; i < rcount; i++) + dst[i] = rptr[i]; + + // Temporarily override member repeats using a local Tile + Tile tile_op; + tile_op.axis = axis; + tile_op.tiles = tiles; + tile_op.repeats = repeats_param; + + return tile_op.forward(bottom_blobs[0], top_blobs[0], opt); + } + + return forward(bottom_blobs[0], top_blobs[0], opt); +} + int Tile::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { int dims = bottom_blob.dims; @@ -100,36 +130,57 @@ int Tile::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons size_t elemsize = bottom_blob.elemsize; const int outdims = std::max(dims, repeats_num); - - if (repeat_w == 1 && repeat_h == 1 && repeat_d == 1 && repeat_c == 1) + if (repeat_w != 1 && repeat_h == 1 && repeat_d == 1 && repeat_c == 1) { - // all ones - if (repeats_num == 0 || dims == repeats_num) - { - top_blob = bottom_blob; - return 0; - } + if (outdims == 1) + top_blob.create(w * repeat_w, elemsize, opt.blob_allocator); + if (outdims == 2) + top_blob.create(w * repeat_w, h, elemsize, opt.blob_allocator); + if (outdims == 3) + top_blob.create(w * repeat_w, h, channels, elemsize, opt.blob_allocator); + if (outdims == 4) + top_blob.create(w * repeat_w, h, d, channels, elemsize, opt.blob_allocator); } - - int outw = w * repeat_w; - int outh = h * repeat_h; - int outd = d * repeat_d; - int outc = channels * repeat_c; - if (outdims == 1) + else if (repeat_h != 1 && repeat_d == 1 && repeat_c == 1) { - top_blob.create(outw, elemsize, opt.blob_allocator); + if (outdims == 2) + top_blob.create(w * repeat_w, h * repeat_h, elemsize, opt.blob_allocator); + if (outdims == 3) + top_blob.create(w * repeat_w, h * repeat_h, channels, elemsize, opt.blob_allocator); + if (outdims == 4) + top_blob.create(w * repeat_w, h * repeat_h, d, 
channels, elemsize, opt.blob_allocator); } - if (outdims == 2) + else if (repeat_d != 1 && repeat_c == 1) { - top_blob.create(outw, outh, elemsize, opt.blob_allocator); + if (outdims == 4) + top_blob.create(w * repeat_w, h * repeat_h, d * repeat_d, channels, elemsize, opt.blob_allocator); } - if (outdims == 3) + else if (repeat_d == 1 && repeat_c != 1) { - top_blob.create(outw, outh, outc, elemsize, opt.blob_allocator); + if (outdims == 3) + top_blob.create(w * repeat_w, h * repeat_h, channels * repeat_c, elemsize, opt.blob_allocator); + if (outdims == 4) + top_blob.create(w * repeat_w, h * repeat_h, d, channels * repeat_c, elemsize, opt.blob_allocator); } - if (outdims == 4) + else if (repeat_d != 1 && repeat_c != 1) { - top_blob.create(outw, outh, outd, outc, elemsize, opt.blob_allocator); + if (outdims == 4) + top_blob.create(w * repeat_w, h * repeat_h, d * repeat_d, channels * repeat_c, elemsize, opt.blob_allocator); + } + else // all ones + { + if (repeats_num == 0 || dims == repeats_num) + { + top_blob = bottom_blob; + return 0; + } + + if (outdims == 2) + top_blob.create(w * repeat_w, h * repeat_h, elemsize, opt.blob_allocator); + if (outdims == 3) + top_blob.create(w * repeat_w, h * repeat_h, channels * repeat_c, elemsize, opt.blob_allocator); + if (outdims == 4) + top_blob.create(w * repeat_w, h * repeat_h, d * repeat_d, channels * repeat_c, elemsize, opt.blob_allocator); } if (top_blob.empty()) return -100; diff --git a/src/layer/tile.h b/src/layer/tile.h index 7fc9ae630c6e..ffa92225c8b0 100644 --- a/src/layer/tile.h +++ b/src/layer/tile.h @@ -15,6 +15,7 @@ class Tile : public Layer virtual int load_param(const ParamDict& pd); + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; public: diff --git a/src/layer/topk.cpp b/src/layer/topk.cpp new file mode 100644 index 000000000000..a2c42383ded9 --- /dev/null +++ 
// Reports whether v is a NaN by inspecting its bit pattern:
// with the sign bit cleared, a NaN is strictly greater than the infinity encoding.
static inline bool topk_isnan(float v)
{
    uint32_t bits;
    memcpy(&bits, &v, sizeof(uint32_t));

    const uint32_t magnitude = bits & 0x7fffffff;
    return magnitude > 0x7f800000;
}

// Strict ordering of (value, index) pairs for top-k selection.
// Non-NaN values come first, ordered descending when largest is set and
// ascending otherwise; equal values and NaN pairs are ordered by ascending index.
// Returns true when a must be placed before b.
static inline bool topk_pair_comp(const std::pair<float, int>& a, const std::pair<float, int>& b, bool largest)
{
    const bool a_nan = topk_isnan(a.first);
    const bool b_nan = topk_isnan(b.first);

    if (!a_nan && !b_nan)
    {
        if (a.first == b.first)
            return a.second < b.second;

        return largest ? (a.first > b.first) : (a.first < b.first);
    }

    // two NaNs: only the index can decide
    if (a_nan && b_nan)
        return a.second < b.second;

    // exactly one NaN: a goes first only when the NaN is b
    return b_nan;
}

// Same ordering as the NaN-aware comparison, for callers that already know
// neither value is NaN.
static inline bool topk_value_index_comp_nonnan(float a_value, int a_index, float b_value, int b_index, bool largest)
{
    if (a_value == b_value)
        return a_index < b_index;

    return largest ? (a_value > b_value) : (a_value < b_value);
}
(a_value > b_value) : (a_value < b_value); + + return a_index < b_index; +} + +struct topk_pair_comparator +{ + topk_pair_comparator(bool _largest) + : largest(_largest) + { + } + + bool operator()(const std::pair& a, const std::pair& b) const + { + return topk_pair_comp(a, b, largest); + } + + bool largest; +}; + +TopK::TopK() +{ + one_blob_only = false; + support_inplace = false; +} + +int TopK::load_param(const ParamDict& pd) +{ + axis = pd.get(0, -1); + largest = pd.get(1, 1); + sorted = pd.get(2, 1); + k = pd.get(3, 1); + + return 0; +} + +int TopK::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + if (bottom_blobs.empty()) + return -1; + + const Mat& bottom_blob = bottom_blobs[0]; + + int _k = k; + if (bottom_blobs.size() >= 2) + { + const Mat& k_blob = bottom_blobs[1]; + if (k_blob.total() < 1) + return -1; + + const size_t k_elemsize = k_blob.elemsize / k_blob.elempack; + if (k_elemsize == 8) + _k = (int)((const int64_t*)(const void*)k_blob)[0]; + else if (k_elemsize == 4) + _k = ((const int*)(const void*)k_blob)[0]; + else + return -1; + } + + if (bottom_blob.dims < 1 || bottom_blob.dims > 4) + return -100; + + const int dims = bottom_blob.dims; + + const int positive_axis = axis < 0 ? axis + dims : axis; + if (positive_axis < 0 || positive_axis >= dims) + return -1; + + int shape[4] = {1, 1, 1, 1}; + shape[0] = bottom_blob.w; + if (dims >= 2) shape[1] = bottom_blob.h; + if (dims >= 3) shape[2] = bottom_blob.dims == 3 ? 
bottom_blob.c : bottom_blob.d; + if (dims >= 4) shape[3] = bottom_blob.c; + + const int axis_size = shape[positive_axis]; + if (axis_size <= 0) + return -1; + + if (_k < 0) + return -1; + if (_k > axis_size) + _k = axis_size; + + if (_k == 0) + { + // Return empty (zero-sized) output blobs without allocation + top_blobs[0] = Mat(); + if (top_blobs.size() >= 2) + top_blobs[1] = Mat(); + return 0; + } + + int out_shape[4] = {shape[0], shape[1], shape[2], shape[3]}; + out_shape[positive_axis] = _k; + + Mat values; + if (dims == 1) values.create(out_shape[0], 4u, opt.blob_allocator); + if (dims == 2) values.create(out_shape[0], out_shape[1], 4u, opt.blob_allocator); + if (dims == 3) values.create(out_shape[0], out_shape[1], out_shape[2], 4u, opt.blob_allocator); + if (dims == 4) values.create(out_shape[0], out_shape[1], out_shape[2], out_shape[3], 4u, opt.blob_allocator); + if (values.empty()) + return -100; + + Mat indices; + if (top_blobs.size() >= 2) + { + if (dims == 1) indices.create(out_shape[0], 4u, opt.blob_allocator); + if (dims == 2) indices.create(out_shape[0], out_shape[1], 4u, opt.blob_allocator); + if (dims == 3) indices.create(out_shape[0], out_shape[1], out_shape[2], 4u, opt.blob_allocator); + if (dims == 4) indices.create(out_shape[0], out_shape[1], out_shape[2], out_shape[3], 4u, opt.blob_allocator); + if (indices.empty()) + return -100; + } + + const float* ptr = bottom_blob; + float* outptr = values; + int* outidxptr = (int*)(void*)(indices.data); + const bool output_indices = outidxptr != 0; + + int inner = 1; + for (int i = 0; i < positive_axis; i++) + { + inner *= shape[i]; + } + + int outer = 1; + for (int i = positive_axis + 1; i < dims; i++) + { + outer *= shape[i]; + } + + const bool largest_flag = largest != 0; + const bool sorted_flag = sorted != 0; + + const int total_lines = outer * inner; + + // ncnn 3-/4-D mats have a channel stride (cstep) that may be larger than w*h + // due to alignment padding. 
The flat inner/outer indexing must account for this: + // - when axis reduces a non-channel dim, the outer loop spans channels and + // the channel offset must use cstep rather than the product of spatial sizes; + // - when axis IS the channel dim, the per-element j-stride must be cstep. + const size_t in_cstep = (dims >= 3) ? (size_t)bottom_blob.cstep : 0; + const size_t out_cstep = (dims >= 3) ? values.cstep : 0; + const bool axis_is_channel = (dims >= 3 && positive_axis == dims - 1); + // spatial-only outer count: channels factored out so cstep can be used separately + const int c_channels = (!axis_is_channel && dims >= 3) ? shape[dims - 1] : 1; + const int outer_spatial = (dims >= 3 && !axis_is_channel) ? outer / c_channels : outer; + // stride when stepping along the axis in memory + const size_t in_axis_stride = axis_is_channel ? in_cstep : (size_t)inner; + const size_t out_axis_stride = axis_is_channel ? out_cstep : (size_t)inner; + + if (_k == 1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int line = 0; line < total_lines; line++) + { + int outer_i = line / inner; + int inner_i = line - outer_i * inner; + + size_t in_base, out_base; + if (!axis_is_channel && dims >= 3) + { + const int ci = outer_i / outer_spatial; + const int sp_i = outer_i % outer_spatial; + in_base = (size_t)ci * in_cstep + (size_t)sp_i * axis_size * inner + inner_i; + out_base = (size_t)ci * out_cstep + (size_t)sp_i * 1 * inner + inner_i; + } + else + { + in_base = (size_t)outer_i * axis_size * inner + inner_i; + out_base = (size_t)outer_i * 1 * inner + inner_i; + } + +#if __ARM_NEON + // Fast path: NEON-optimized k=1 without indices (values-only) + // Requires: no NaN values in input (NaN breaks vector comparisons) + if (!output_indices && inner == 1 && axis_size >= 4) + { + const float* lineptr = ptr + in_base; + + // Pre-scan for NaN - if found, fall through to NaN-aware scalar path + bool has_nan = false; + for (int j = 0; j < axis_size; j++) + { + if 
(topk_isnan(lineptr[j])) + { + has_nan = true; + break; + } + } + + if (!has_nan) + { + // Accumulate best4 across all NEON chunks; reduce to scalar only once. + float32x4_t best4 = vld1q_f32(lineptr); + int j = 4; + + for (; j + 3 < axis_size; j += 4) + { + float32x4_t v = vld1q_f32(lineptr + j); + best4 = largest_flag ? vmaxq_f32(best4, v) : vminq_f32(best4, v); + } + + // Reduce best4 to scalar once after the loop + float32x2_t m = largest_flag + ? vpmax_f32(vget_low_f32(best4), vget_high_f32(best4)) + : vpmin_f32(vget_low_f32(best4), vget_high_f32(best4)); + m = largest_flag ? vpmax_f32(m, m) : vpmin_f32(m, m); + float best_value = vget_lane_f32(m, 0); + + // Handle remaining elements (scalar) + for (; j < axis_size; j++) + { + const float candidate_value = lineptr[j]; + if (largest_flag) + { + if (candidate_value > best_value) + best_value = candidate_value; + } + else + { + if (candidate_value < best_value) + best_value = candidate_value; + } + } + + outptr[out_base] = best_value; + continue; + } + // Fall through to NaN-aware scalar path for proper tie-breaking + } +#endif // __ARM_NEON + + float best_value = ptr[in_base]; + int best_index = 0; + + // Fast path: no NaN check per comparison pair (common case). + // topk_value_index_comp checks both operands for NaN on every call; + // here we check only the candidate, and fall back only when NaN is found. 
+ bool has_nan = topk_isnan(best_value); + if (!has_nan) + { + if (largest_flag) + { + for (int j = 1; j < axis_size; j++) + { + const float v = ptr[in_base + j * in_axis_stride]; + if (topk_isnan(v)) + { + has_nan = true; + break; + } + if (v > best_value) + { + best_value = v; + best_index = j; + } + } + } + else + { + for (int j = 1; j < axis_size; j++) + { + const float v = ptr[in_base + j * in_axis_stride]; + if (topk_isnan(v)) + { + has_nan = true; + break; + } + if (v < best_value) + { + best_value = v; + best_index = j; + } + } + } + } + if (has_nan) + { + // NaN-aware fallback: NaN sorts last, ties broken by index. + best_value = ptr[in_base]; + best_index = 0; + for (int j = 1; j < axis_size; j++) + { + const float v = ptr[in_base + j * in_axis_stride]; + if (topk_value_index_comp(v, j, best_value, best_index, largest_flag)) + { + best_value = v; + best_index = j; + } + } + } + + outptr[out_base] = best_value; + if (output_indices) + outidxptr[out_base] = best_index; + } + + top_blobs[0] = values; + if (top_blobs.size() >= 2) + top_blobs[1] = indices; + + return 0; + } + + if (_k == axis_size && !sorted_flag) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int line = 0; line < total_lines; line++) + { + int outer_i = line / inner; + int inner_i = line - outer_i * inner; + + size_t in_base, out_base; + if (!axis_is_channel && dims >= 3) + { + const int ci = outer_i / outer_spatial; + const int sp_i = outer_i % outer_spatial; + in_base = (size_t)ci * in_cstep + (size_t)sp_i * axis_size * inner + inner_i; + out_base = (size_t)ci * out_cstep + (size_t)sp_i * _k * inner + inner_i; + } + else + { + in_base = (size_t)outer_i * axis_size * inner + inner_i; + out_base = (size_t)outer_i * _k * inner + inner_i; + } + + if (output_indices) + { + for (int j = 0; j < _k; j++) + { + outptr[out_base + j * out_axis_stride] = ptr[in_base + j * in_axis_stride]; + outidxptr[out_base + j * out_axis_stride] = j; + } + } + else + { + for (int j = 0; j < _k; 
j++) + { + outptr[out_base + j * out_axis_stride] = ptr[in_base + j * in_axis_stride]; + } + } + } + + top_blobs[0] = values; + if (top_blobs.size() >= 2) + top_blobs[1] = indices; + + return 0; + } + + if (_k <= 4) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int line = 0; line < total_lines; line++) + { + int outer_i = line / inner; + int inner_i = line - outer_i * inner; + + size_t in_base, out_base; + if (!axis_is_channel && dims >= 3) + { + const int ci = outer_i / outer_spatial; + const int sp_i = outer_i % outer_spatial; + in_base = (size_t)ci * in_cstep + (size_t)sp_i * axis_size * inner + inner_i; + out_base = (size_t)ci * out_cstep + (size_t)sp_i * _k * inner + inner_i; + } + else + { + in_base = (size_t)outer_i * axis_size * inner + inner_i; + out_base = (size_t)outer_i * _k * inner + inner_i; + } + + float top_values[4]; + int top_indices[4]; + int top_count = 0; + + // has_nan_in_top: tracks whether the current top-k buffer contains any NaN. + // When false, use the cheaper non-NaN comparator in the insertion sort. + bool has_nan_in_top = false; + + if (sorted_flag) + { + for (int j = 0; j < axis_size; j++) + { + const float candidate_value = ptr[in_base + j * in_axis_stride]; + const bool cand_nan = topk_isnan(candidate_value); + +// Select comparator: skip NaN handling when neither side has NaN. +#define COMP_K4(a_v, a_i, b_v, b_i) \ + ((!cand_nan && !has_nan_in_top) \ + ? 
topk_value_index_comp_nonnan(a_v, a_i, b_v, b_i, largest_flag) \ + : topk_value_index_comp(a_v, a_i, b_v, b_i, largest_flag)) + + if (top_count < _k) + { + int insert_pos = top_count; + while (insert_pos > 0 && COMP_K4(candidate_value, j, top_values[insert_pos - 1], top_indices[insert_pos - 1])) + { + top_values[insert_pos] = top_values[insert_pos - 1]; + top_indices[insert_pos] = top_indices[insert_pos - 1]; + insert_pos--; + } + + top_values[insert_pos] = candidate_value; + top_indices[insert_pos] = j; + top_count++; + if (cand_nan) has_nan_in_top = true; + } + else if (COMP_K4(candidate_value, j, top_values[_k - 1], top_indices[_k - 1])) + { + if (!cand_nan && has_nan_in_top) + { + // Evicting a NaN: recheck whether any NaN remains in top buffer. + has_nan_in_top = false; + for (int t = 0; t < _k - 1; t++) + if (topk_isnan(top_values[t])) + { + has_nan_in_top = true; + break; + } + } + + int insert_pos = _k - 1; + while (insert_pos > 0 && COMP_K4(candidate_value, j, top_values[insert_pos - 1], top_indices[insert_pos - 1])) + { + top_values[insert_pos] = top_values[insert_pos - 1]; + top_indices[insert_pos] = top_indices[insert_pos - 1]; + insert_pos--; + } + + top_values[insert_pos] = candidate_value; + top_indices[insert_pos] = j; + if (cand_nan) has_nan_in_top = true; + } + +#undef COMP_K4 + } + } + else + { + for (int j = 0; j < axis_size; j++) + { + const float candidate_value = ptr[in_base + j * in_axis_stride]; + const bool cand_nan = topk_isnan(candidate_value); + + if (top_count < _k) + { + top_values[top_count] = candidate_value; + top_indices[top_count] = j; + top_count++; + if (cand_nan) has_nan_in_top = true; + } + else + { + const bool use_fast = (!cand_nan && !has_nan_in_top); + int worst_pos = 0; + for (int t = 1; t < _k; t++) + { + bool is_worse = use_fast + ? 
topk_value_index_comp_nonnan(top_values[worst_pos], top_indices[worst_pos], top_values[t], top_indices[t], largest_flag) + : topk_value_index_comp(top_values[worst_pos], top_indices[worst_pos], top_values[t], top_indices[t], largest_flag); + if (is_worse) worst_pos = t; + } + + bool replace = use_fast + ? topk_value_index_comp_nonnan(candidate_value, j, top_values[worst_pos], top_indices[worst_pos], largest_flag) + : topk_value_index_comp(candidate_value, j, top_values[worst_pos], top_indices[worst_pos], largest_flag); + + if (replace) + { + if (!cand_nan && has_nan_in_top) + { + has_nan_in_top = false; + for (int t = 0; t < _k; t++) + if (t != worst_pos && topk_isnan(top_values[t])) + { + has_nan_in_top = true; + break; + } + } + top_values[worst_pos] = candidate_value; + top_indices[worst_pos] = j; + if (cand_nan) has_nan_in_top = true; + } + } + } + } + + if (output_indices) + { + for (int j = 0; j < _k; j++) + { + outptr[out_base + j * out_axis_stride] = top_values[j]; + outidxptr[out_base + j * out_axis_stride] = top_indices[j]; + } + } + else + { + for (int j = 0; j < _k; j++) + { + outptr[out_base + j * out_axis_stride] = top_values[j]; + } + } + } + + top_blobs[0] = values; + if (top_blobs.size() >= 2) + top_blobs[1] = indices; + + return 0; + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int line = 0; line < total_lines; line++) + { + // Reuse thread-local scratch to avoid one malloc/free per line. 
+#if !NCNN_SIMPLESTL + static thread_local std::vector > tl_vec; + tl_vec.resize(axis_size); + std::vector >& vec = tl_vec; +#else + std::vector > vec(axis_size); +#endif + + topk_pair_comparator comp(largest_flag); + + int outer_i = line / inner; + int inner_i = line - outer_i * inner; + + size_t in_base, out_base; + if (!axis_is_channel && dims >= 3) + { + const int ci = outer_i / outer_spatial; + const int sp_i = outer_i % outer_spatial; + in_base = (size_t)ci * in_cstep + (size_t)sp_i * axis_size * inner + inner_i; + out_base = (size_t)ci * out_cstep + (size_t)sp_i * _k * inner + inner_i; + } + else + { + in_base = (size_t)outer_i * axis_size * inner + inner_i; + out_base = (size_t)outer_i * _k * inner + inner_i; + } + + for (int j = 0; j < axis_size; j++) + { + vec[j].first = ptr[in_base + j * in_axis_stride]; + vec[j].second = j; + } + + if (_k < axis_size) + { +#if NCNN_SIMPLESTL + std::partial_sort(vec.begin(), vec.begin() + _k, vec.end(), comp); +#else + if (sorted_flag) + { + std::nth_element(vec.begin(), vec.begin() + _k, vec.end(), comp); + std::sort(vec.begin(), vec.begin() + _k, comp); + } + else + std::nth_element(vec.begin(), vec.begin() + _k, vec.end(), comp); +#endif + } + else + { + if (sorted_flag) +#if NCNN_SIMPLESTL + std::partial_sort(vec.begin(), vec.end(), vec.end(), comp); +#else + std::sort(vec.begin(), vec.end(), comp); +#endif + } + + if (output_indices) + { + for (int j = 0; j < _k; j++) + { + outptr[out_base + j * out_axis_stride] = vec[j].first; + outidxptr[out_base + j * out_axis_stride] = vec[j].second; + } + } + else + { + for (int j = 0; j < _k; j++) + { + outptr[out_base + j * out_axis_stride] = vec[j].first; + } + } + } + + top_blobs[0] = values; + if (top_blobs.size() >= 2) + top_blobs[1] = indices; + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/topk.h b/src/layer/topk.h new file mode 100644 index 000000000000..947dc21343ff --- /dev/null +++ b/src/layer/topk.h @@ -0,0 +1,29 @@ +// Copyright 2025 Tencent +// 
SPDX-License-Identifier: BSD-3-Clause + +#ifndef LAYER_TOPK_H +#define LAYER_TOPK_H + +#include "layer.h" + +namespace ncnn { + +class TopK : public Layer +{ +public: + TopK(); + + virtual int load_param(const ParamDict& pd); + + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +public: + int axis; + int largest; + int sorted; + int k; +}; + +} // namespace ncnn + +#endif // LAYER_TOPK_H diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index a273e2221cbe..f05a6f0325b9 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -48,6 +48,23 @@ if(NCNN_PIXEL_DRAWING) ncnn_add_test(mat_pixel_drawing) endif() +# YOLO26 support tests +if(WITH_LAYER_gather) + ncnn_add_test(gather) +endif() + +if(WITH_LAYER_gatherelements) + ncnn_add_test(gatherelements) +endif() + +if(WITH_LAYER_expand) + ncnn_add_test(expand) +endif() + +if(WITH_LAYER_mod) + ncnn_add_test(mod) +endif() + if(NCNN_PIXEL_ROTATE) ncnn_add_test(mat_pixel_rotate) endif() @@ -173,6 +190,7 @@ ncnn_add_layer_test(Swish) ncnn_add_layer_test(TanH) ncnn_add_layer_test(Threshold) ncnn_add_layer_test(Tile) +ncnn_add_layer_test(TopK) ncnn_add_layer_test(UnaryOp) ncnn_add_layer_test(Unfold) ncnn_add_layer_test(Yolov3DetectionOutput) diff --git a/tests/test_expand.cpp b/tests/test_expand.cpp new file mode 100644 index 000000000000..407cfda67ae8 --- /dev/null +++ b/tests/test_expand.cpp @@ -0,0 +1,246 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "testutil.h" + +#include + +static int run_expand(const ncnn::Mat& data, const ncnn::Mat& shape, ncnn::Mat& out) +{ + ncnn::ParamDict pd; + + ncnn::Option opt; + opt.num_threads = 1; + opt.use_vulkan_compute = false; + opt.use_packing_layout = false; + + ncnn::Layer* op = ncnn::create_layer_cpu("Expand"); + if (!op) + return -1; + + op->load_param(pd); + + std::vector weights(0); + ncnn::ModelBinFromMatArray mb(weights.data()); + op->load_model(mb); + 
op->create_pipeline(opt); + + std::vector bottom_blobs(2); + bottom_blobs[0] = data; + bottom_blobs[1] = shape; + + std::vector top_blobs(1); + int ret = op->forward(bottom_blobs, top_blobs, opt); + + op->destroy_pipeline(opt); + delete op; + + if (ret != 0) + return ret; + + out = top_blobs[0]; + return 0; +} + +// Build a 1D int32 shape Mat in ncnn ordering (w, h, c). +static ncnn::Mat make_shape_i32(int w, int h, int c) +{ + ncnn::Mat s(3, (size_t)4u); + int* p = (int*)(void*)s; + p[0] = w; + p[1] = h; + p[2] = c; + return s; +} + +// Build a 1D int64 shape Mat (same values, different elemsize). +static ncnn::Mat make_shape_i64(int w, int h, int c) +{ + ncnn::Mat s(3, (size_t)8u); + int64_t* p = (int64_t*)(void*)s; + p[0] = w; + p[1] = h; + p[2] = c; + return s; +} + +static int check_equal(const ncnn::Mat& a, const ncnn::Mat& b, const char* name) +{ + if (a.dims != b.dims || a.w != b.w || a.h != b.h || a.c != b.c) + { + fprintf(stderr, "%s: shape mismatch got(%d %d %d dims=%d) expected(%d %d %d dims=%d)\n", + name, a.w, a.h, a.c, a.dims, b.w, b.h, b.c, b.dims); + return -1; + } + const float* ap = a; + const float* bp = b; + for (int z = 0; z < a.c; z++) + for (int y = 0; y < a.h; y++) + for (int x = 0; x < a.w; x++) + { + float got = ap[(int)(z * a.cstep) + y * a.w + x]; + float exp = bp[(int)(z * b.cstep) + y * b.w + x]; + if (got != exp) + { + fprintf(stderr, "%s: value mismatch at [%d,%d,%d]: got %f expected %f\n", + name, z, y, x, got, exp); + return -1; + } + } + return 0; +} + +static ncnn::Mat ref_expand(const ncnn::Mat& src, int out_w, int out_h, int out_c) +{ + ncnn::Mat out; + out.create(out_w, out_h, out_c, (size_t)4u); + + const float* sp = src; + float* op = out; + + for (int z = 0; z < out_c; z++) + { + int sz = (src.c > 1) ? z : 0; + const float* sc = sp + sz * (int)src.cstep; + float* dc = op + z * (int)out.cstep; + for (int y = 0; y < out_h; y++) + { + int sy = (src.h > 1) ? 
y : 0; + const float* sr = sc + sy * src.w; + float* dr = dc + y * out_w; + for (int x = 0; x < out_w; x++) + { + int sx = (src.w > 1) ? x : 0; + dr[x] = sr[sx]; + } + } + } + return out; +} + +static int test_expand(const ncnn::Mat& data, int out_w, int out_h, int out_c, const char* name) +{ + ncnn::Mat shape = make_shape_i32(out_w, out_h, out_c); + ncnn::Mat expected = ref_expand(data, out_w, out_h, out_c); + ncnn::Mat got; + if (run_expand(data, shape, got) != 0) + { + fprintf(stderr, "%s: forward failed\n", name); + return -1; + } + return check_equal(got, expected, name); +} + +// --- Tests --- + +static int test_expand_scalar_to_1d() +{ + ncnn::Mat data = RandomMat(1, 1, 1); + return test_expand(data, 10, 1, 1, "expand_scalar_to_w10"); +} + +static int test_expand_broadcast_w() +{ + // in_w=1 → out_w=5: exercises the scalar broadcast fill path (out_w < 16) + ncnn::Mat data = RandomMat(1, 3, 1); + return test_expand(data, 5, 3, 1, "expand_broadcast_w"); +} + +static int test_expand_broadcast_w_neon() +{ + // in_w=1 → out_w=20: out_w >= 16 triggers the NEON 4×-unrolled fill path + ncnn::Mat data = RandomMat(1, 4, 1); + return test_expand(data, 20, 4, 1, "expand_broadcast_w_neon"); +} + +static int test_expand_broadcast_h() +{ + ncnn::Mat data = RandomMat(4, 1, 1); + return test_expand(data, 4, 6, 1, "expand_broadcast_h"); +} + +static int test_expand_broadcast_c() +{ + ncnn::Mat data = RandomMat(4, 3, 1); + return test_expand(data, 4, 3, 8, "expand_broadcast_c"); +} + +static int test_expand_broadcast_wh() +{ + // Broadcasts both w and h simultaneously + ncnn::Mat data = RandomMat(1, 1, 3); + return test_expand(data, 8, 5, 3, "expand_broadcast_wh"); +} + +static int test_expand_full_broadcast() +{ + ncnn::Mat data = RandomMat(1, 1, 1); + return test_expand(data, 4, 6, 8, "expand_full_broadcast"); +} + +static int test_expand_no_broadcast() +{ + ncnn::Mat data = RandomMat(4, 3, 2); + return test_expand(data, 4, 3, 2, "expand_no_broadcast"); +} + +static int 
test_expand_1d_to_3d() +{ + ncnn::Mat data = RandomMat(4); + return test_expand(data, 4, 6, 8, "expand_1d_to_3d"); +} + +static int test_expand_2d_to_3d() +{ + ncnn::Mat data = RandomMat(4, 3); + return test_expand(data, 4, 3, 8, "expand_2d_to_3d"); +} + +// int64 shape blob — exercises the shape_is_int64 branch in Expand::forward. +static int test_expand_int64_shape() +{ + ncnn::Mat data = RandomMat(1, 2, 1); + ncnn::Mat shape = make_shape_i64(6, 2, 4); + ncnn::Mat expected = ref_expand(data, 6, 2, 4); + ncnn::Mat got; + if (run_expand(data, shape, got) != 0) + { + fprintf(stderr, "expand_int64_shape: forward failed\n"); + return -1; + } + return check_equal(got, expected, "expand_int64_shape"); +} + +// -1 in shape means "keep that dimension" (tgt_dim <= 0 branch). +static int test_expand_negative_one_shape() +{ + ncnn::Mat data = RandomMat(4, 3, 2); + // shape = (-1, -1, -1) should return data unchanged + ncnn::Mat shape = make_shape_i32(-1, -1, -1); + ncnn::Mat got; + if (run_expand(data, shape, got) != 0) + { + fprintf(stderr, "expand_negative_one_shape: forward failed\n"); + return -1; + } + return check_equal(got, data, "expand_negative_one_shape"); +} + +int main() +{ + SRAND(7767517); + + return 0 + || test_expand_scalar_to_1d() + || test_expand_broadcast_w() + || test_expand_broadcast_w_neon() + || test_expand_broadcast_h() + || test_expand_broadcast_c() + || test_expand_broadcast_wh() + || test_expand_full_broadcast() + || test_expand_no_broadcast() + || test_expand_1d_to_3d() + || test_expand_2d_to_3d() + || test_expand_int64_shape() + || test_expand_negative_one_shape(); +} diff --git a/tests/test_gather.cpp b/tests/test_gather.cpp new file mode 100644 index 000000000000..f53f78193dd7 --- /dev/null +++ b/tests/test_gather.cpp @@ -0,0 +1,393 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "testutil.h" + +// Run the Gather layer and return the output blob. 
+static int run_gather(const ncnn::Mat& data, const ncnn::Mat& indices, int axis, ncnn::Mat& out,
+                      int num_threads = 1)
+{
+    // data and indices are the two bottom blobs, axis is layer param 0, and the
+    // single top blob is copied into out. Returns 0 on success, -1 when the
+    // layer cannot be created, otherwise the first non-zero status reported by
+    // load_param / load_model / create_pipeline / forward.
+    ncnn::ParamDict pd;
+    pd.set(0, axis);
+
+    ncnn::Option opt;
+    opt.num_threads = num_threads;
+    opt.use_vulkan_compute = false;
+    opt.use_packing_layout = false;
+
+    ncnn::Layer* op = ncnn::create_layer_cpu("Gather");
+    if (!op)
+        return -1;
+
+    // No weights are supplied; the model-bin wraps an empty array.
+    std::vector<ncnn::Mat> weights(0);
+    ncnn::ModelBinFromMatArray mb(weights.data());
+
+    // Never call forward() on a layer whose setup reported an error.
+    int ret = op->load_param(pd);
+    if (ret == 0)
+        ret = op->load_model(mb);
+    if (ret == 0)
+        ret = op->create_pipeline(opt);
+    if (ret != 0)
+    {
+        delete op;
+        return ret;
+    }
+
+    std::vector<ncnn::Mat> bottom_blobs(2);
+    bottom_blobs[0] = data;
+    bottom_blobs[1] = indices;
+
+    std::vector<ncnn::Mat> top_blobs(1);
+    ret = op->forward(bottom_blobs, top_blobs, opt);
+
+    op->destroy_pipeline(opt);
+    delete op;
+
+    if (ret != 0)
+        return ret;
+
+    out = top_blobs[0];
+    return 0;
+}
+
+// Read index at flat element offset, supporting int32 and int64.
+// An elemsize of 8 selects the int64 view; the value is narrowed to int.
+static int read_flat_idx(const ncnn::Mat& m, int flat)
+{
+    if (m.elemsize == 8)
+        return (int)((const int64_t*)(const void*)m)[flat];
+    return ((const int*)(const void*)m)[flat];
+}
+
+// Normalize a raw index for an axis of axis_size entries: a negative value is
+// first offset by axis_size, then anything still outside [0, axis_size) is
+// clamped to the nearest valid position.
+static int wrap_and_clamp_index(int gi, int axis_size)
+{
+    if (gi < 0) gi += axis_size;
+    if (gi < 0) gi = 0;
+    if (gi >= axis_size) gi = axis_size - 1;
+    return gi;
+}
+
+// Reference gather: PyTorch-style axis ordering (axis=0 = outermost).
+//   1D axis=0: out[x]     = data[idx[x]]
+//   2D axis=0: out[y,x]   = data[idx[y,x], x]
+//   2D axis=1: out[y,x]   = data[y, idx[y,x]]
+//   3D axis=0: out[z,y,x] = data[idx[z,y,x], y, x]
+//   3D axis=1: out[z,y,x] = data[z, idx[z,y,x], x]
+//   3D axis=2: out[z,y,x] = data[z, y, idx[z,y,x]]
+// The output has the shape of indices and holds 4-byte elements.
+static ncnn::Mat ref_gather(const ncnn::Mat& data, const ncnn::Mat& indices, int axis)
+{
+    const int dims = data.dims;
+    const int positive_axis = axis < 0 ? axis + dims : axis;
+
+    // Extents listed outermost-first so that shape[positive_axis] is the
+    // length of the gathered axis.
+    int shape[3] = {1, 1, 1};
+    if (dims == 1)
+        shape[0] = data.w;
+    else if (dims == 2)
+    {
+        shape[0] = data.h;
+        shape[1] = data.w;
+    }
+    else
+    {
+        shape[0] = data.c;
+        shape[1] = data.h;
+        shape[2] = data.w;
+    }
+    const int axis_size = shape[positive_axis];
+
+    ncnn::Mat out;
+    if (indices.dims == 1)
+        out.create(indices.w, (size_t)4u);
+    else if (indices.dims == 2)
+        out.create(indices.w, indices.h, (size_t)4u);
+    else
+        out.create(indices.w, indices.h, indices.c, (size_t)4u);
+
+    const float* dp = data;
+    float* op_ptr = out;
+
+    if (dims == 1)
+    {
+        for (int x = 0; x < indices.w; x++)
+        {
+            const int gi = wrap_and_clamp_index(read_flat_idx(indices, x), axis_size);
+            op_ptr[x] = dp[gi];
+        }
+    }
+    else if (dims == 2)
+    {
+        const int dw = data.w;
+        const int idxw = indices.w;
+        for (int y = 0; y < indices.h; y++)
+            for (int x = 0; x < idxw; x++)
+            {
+                const int gi = wrap_and_clamp_index(read_flat_idx(indices, y * idxw + x), axis_size);
+                // axis 0 replaces the row coordinate, axis 1 the column coordinate.
+                const int src = (positive_axis == 0) ? gi * dw + x : y * dw + gi;
+                op_ptr[y * out.w + x] = dp[src];
+            }
+    }
+    else // dims == 3
+    {
+        const int dw = data.w;
+        // Channels are addressed through cstep rather than w * h.
+        const size_t d_cstep = data.cstep;
+        const size_t i_cstep = indices.cstep;
+        const size_t o_cstep = out.cstep;
+        const int idxw = indices.w;
+
+        for (int z = 0; z < indices.c; z++)
+            for (int y = 0; y < indices.h; y++)
+                for (int x = 0; x < idxw; x++)
+                {
+                    const int raw = read_flat_idx(indices, (int)(z * i_cstep) + y * idxw + x);
+                    const int gi = wrap_and_clamp_index(raw, axis_size);
+
+                    float val;
+                    if (positive_axis == 0)
+                        val = dp[(int)(gi * d_cstep) + y * dw + x];
+                    else if (positive_axis == 1)
+                        val = dp[(int)(z * d_cstep) + gi * dw + x];
+                    else
+                        val = dp[(int)(z * d_cstep) + y * dw + gi];
+
+                    op_ptr[(int)(z * o_cstep) + y * out.w + x] = val;
+                }
+    }
+
+    return out;
+}
+
+// Build an int32 index Mat with values in [0, axis_size).
+// Uses a deterministic pattern: idx[i] = (i * 3 + 1) % axis_size.
+// The Mat is 3D when c > 1, 2D when h > 1, otherwise 1D.
+static ncnn::Mat make_indices(int w, int h, int c, int axis_size)
+{
+    ncnn::Mat m;
+    if (c > 1)
+        m.create(w, h, c, (size_t)4u);
+    else if (h > 1)
+        m.create(w, h, (size_t)4u);
+    else
+        m.create(w, (size_t)4u);
+
+    int* p = (int*)(void*)m;
+    int total = (int)m.total();
+    for (int i = 0; i < total; i++)
+        p[i] = (i * 3 + 1) % axis_size;
+    return m;
+}
+
+// Build an int64 index Mat with the same pattern.
+static ncnn::Mat make_indices_i64(int w, int h, int c, int axis_size)
+{
+    ncnn::Mat m;
+    if (c > 1)
+        m.create(w, h, c, (size_t)8u);
+    else if (h > 1)
+        m.create(w, h, (size_t)8u);
+    else
+        m.create(w, (size_t)8u);
+
+    int64_t* p = (int64_t*)(void*)m;
+    int total = (int)m.total();
+    for (int i = 0; i < total; i++)
+        p[i] = (i * 3 + 1) % axis_size;
+    return m;
+}
+
+// Compare two float blobs element by element with exact equality.
+// a is reported as "got" and b as "expected". Returns 0 when the shapes and
+// all values match, -1 (after printing the first difference) otherwise.
+static int check_equal(const ncnn::Mat& a, const ncnn::Mat& b, const char* name)
+{
+    if (a.dims != b.dims || a.w != b.w || a.h != b.h || a.c != b.c)
+    {
+        fprintf(stderr, "%s: shape mismatch got(%d %d %d dims=%d) expected(%d %d %d dims=%d)\n",
+                name, a.w, a.h, a.c, a.dims, b.w, b.h, b.c, b.dims);
+        return -1;
+    }
+    // Use explicit loops to avoid comparing uninitialized cstep padding bytes
+    const float* ad = (const float*)a.data;
+    const float* bd = (const float*)b.data;
+    for (int z = 0; z < a.c; z++)
+        for (int y = 0; y < a.h; y++)
+            for (int x = 0; x < a.w; x++)
+            {
+                float av = ad[z * a.cstep + y * a.w + x];
+                float bv = bd[z * b.cstep + y * b.w + x];
+                if (av != bv)
+                {
+                    fprintf(stderr, "%s: value mismatch at z=%d y=%d x=%d: got %f expected %f\n",
+                            name, z, y, x, av, bv);
+                    return -1;
+                }
+            }
+    return 0;
+}
+
+// Run the layer single-threaded and compare it with ref_gather.
+static int test_gather(const ncnn::Mat& data, const ncnn::Mat& indices, int axis, const char* name)
+{
+    ncnn::Mat expected = ref_gather(data, indices, axis);
+    ncnn::Mat got;
+    if (run_gather(data, indices, axis, got) != 0)
+    {
+        fprintf(stderr, "%s: forward failed\n", name);
+        return -1;
+    }
+    return check_equal(got, expected, name);
+}
+
+static int test_gather_1d()
+{
+    ncnn::Mat data = RandomMat(10);
+    ncnn::Mat idx = make_indices(5, 1, 1, 10);
+    return test_gather(data, idx, 0, "gather_1d_axis0");
+}
+
+static int test_gather_2d()
+{
+    ncnn::Mat data = RandomMat(8, 5); // w=8 h=5
+
+    // axis=0 (PyTorch outermost = h, size=5), index shape [3,8]
+    ncnn::Mat idx0 = make_indices(8, 3, 1, 5);
+    if (test_gather(data, idx0, 0, "gather_2d_axis0") != 0) return -1;
+
+    // axis=1 (PyTorch innermost = w, size=8), index shape [5,4]
+    ncnn::Mat idx1 = make_indices(4, 5, 1, 8);
+    if (test_gather(data, idx1, 1, "gather_2d_axis1") != 0) return -1;
+
+    return 0;
+}
+
+static int test_gather_3d()
+{
+    ncnn::Mat data = RandomMat(8, 6, 4); // w=8 h=6 c=4
+
+    // axis=0 (c, size=4), index shape [2,6,8]
+    ncnn::Mat idx0 = make_indices(8, 6, 2, 4);
+    if (test_gather(data, idx0, 0, "gather_3d_axis0") != 0) return -1;
+
+    // axis=1 (h, size=6), index shape [4,3,8]
+    ncnn::Mat idx1 = make_indices(8, 3, 4, 6);
+    if (test_gather(data, idx1, 1, "gather_3d_axis1") != 0) return -1;
+
+    // axis=2 (w, size=8), index shape [4,6,5]
+    ncnn::Mat idx2 = make_indices(5, 6, 4, 8);
+    if (test_gather(data, idx2, 2, "gather_3d_axis2") != 0) return -1;
+
+    return 0;
+}
+
+static int test_gather_negative_axis()
+{
+    ncnn::Mat data = RandomMat(8, 6, 4); // w=8 h=6 c=4
+
+    // axis=-1 == axis=2 (w, size=8)
+    ncnn::Mat idx = make_indices(5, 6, 4, 8);
+    if (test_gather(data, idx, -1, "gather_3d_axis-1") != 0) return -1;
+
+    // axis=-3 == axis=0 (c, size=4)
+    ncnn::Mat idx0 = make_indices(8, 6, 2, 4);
+    if (test_gather(data, idx0, -3, "gather_3d_axis-3") != 0) return -1;
+
+    return 0;
+}
+
+static int test_gather_clamp()
+{
+    // 1D: out-of-range indices must clamp, not crash.
+    ncnn::Mat data = RandomMat(6);
+    ncnn::Mat idx;
+    idx.create(4, (size_t)4u);
+    int* p = (int*)(void*)idx;
+    p[0] = -10; // clamps to 0
+    p[1] = 0;
+    p[2] = 5;
+    p[3] = 100; // clamps to 5
+
+    if (test_gather(data, idx, 0, "gather_clamp_1d") != 0) return -1;
+
+    // 2D axis=0: negative row indices. With h=4 the value -1 wraps to row 3,
+    // so this case covers wrap-around only; no index here needs clamping.
+    {
+        ncnn::Mat data2d = RandomMat(5, 4); // h=4, w=5
+        ncnn::Mat idx2d;
+        idx2d.create(5, 3, (size_t)4u); // index shape [3, 5]
+        int* q = (int*)(void*)idx2d;
+        for (int i = 0; i < 15; i++) q[i] = (i % 3) - 1; // values: -1, 0, 1
+        if (test_gather(data2d, idx2d, 0, "gather_clamp_2d_axis0") != 0) return -1;
+    }
+
+    // 2D axis=1: out-of-range column indices
+    {
+        ncnn::Mat data2d = RandomMat(5, 4);
+        ncnn::Mat idx2d;
+        idx2d.create(3, 4, (size_t)4u);
+        int* q = (int*)(void*)idx2d;
+        for (int i = 0; i < 12; i++) q[i] = (i % 7) - 1; // -1 wraps to 4, 5 clamps to 4 (w=5)
+        if (test_gather(data2d, idx2d, 1, "gather_clamp_2d_axis1") != 0) return -1;
+    }
+
+    // 3D axis=2: out-of-range indices in the innermost dim
+    {
+        ncnn::Mat data3d = RandomMat(6, 4, 3);
+        ncnn::Mat idx3d;
+        idx3d.create(4, 4, 3, (size_t)4u);
+        int* q = (int*)(void*)idx3d;
+        for (int i = 0; i < (int)idx3d.total(); i++) q[i] = (i % 9) - 2; // includes negatives and overflow
+        if (test_gather(data3d, idx3d, 2, "gather_clamp_3d_axis2") != 0) return -1;
+    }
+
+    return 0;
+}
+
+// Multi-threaded: result must match single-threaded (catches OMP data races).
+static int test_gather_multithread()
+{
+    // Same data and indices are pushed through the layer twice, once with a
+    // single thread and once with four; the two outputs must be identical.
+    ncnn::Mat data = RandomMat(16, 12, 8);
+    ncnn::Mat idx = make_indices(12, 8, 8, 12); // axis=1 (h=12)
+
+    ncnn::Mat out_single, out_multi;
+    // The 4-thread run is only attempted when the 1-thread run succeeded.
+    bool ran = run_gather(data, idx, 1, out_single, 1) == 0;
+    ran = ran && run_gather(data, idx, 1, out_multi, 4) == 0;
+    if (!ran)
+    {
+        fprintf(stderr, "gather_multithread: forward failed\n");
+        return -1;
+    }
+
+    return check_equal(out_single, out_multi, "gather_multithread");
+}
+
+static int test_gather_int64_indices()
+{
+    // Verify the int64 index path (elemsize==8) works identically to int32.
+    ncnn::Mat data = RandomMat(8, 5); // w=8 h=5
+
+    // 2D axis=0 with int64 indices
+    ncnn::Mat idx0_i64 = make_indices_i64(8, 3, 1, 5);
+    int status = test_gather(data, idx0_i64, 0, "gather_i64_2d_axis0");
+    if (status != 0)
+        return -1;
+
+    // 2D axis=1 with int64 indices
+    ncnn::Mat idx1_i64 = make_indices_i64(4, 5, 1, 8);
+    status = test_gather(data, idx1_i64, 1, "gather_i64_2d_axis1");
+    if (status != 0)
+        return -1;
+
+    // 3D axis=1 with int64 indices
+    ncnn::Mat data3d = RandomMat(8, 6, 4);
+    ncnn::Mat idx3d_i64 = make_indices_i64(8, 3, 4, 6);
+    status = test_gather(data3d, idx3d_i64, 1, "gather_i64_3d_axis1");
+    if (status != 0)
+        return -1;
+
+    return 0;
+}
+
+int main()
+{
+    SRAND(7767517);
+
+    // Cases run in order; the first non-zero result stops the run and the
+    // process exits with 1, otherwise with 0.
+    typedef int (*test_case)();
+    static const test_case cases[] = {
+        test_gather_1d,
+        test_gather_2d,
+        test_gather_3d,
+        test_gather_negative_axis,
+        test_gather_clamp,
+        test_gather_int64_indices,
+        test_gather_multithread
+    };
+
+    const int case_count = (int)(sizeof(cases) / sizeof(cases[0]));
+    for (int i = 0; i < case_count; i++)
+    {
+        if (cases[i]() != 0)
+            return 1;
+    }
+    return 0;
+}
diff --git a/tests/test_gatherelements.cpp b/tests/test_gatherelements.cpp
new file mode 100644
index 000000000000..a7d07e5c62a1
--- /dev/null
+++ b/tests/test_gatherelements.cpp
@@ -0,0 +1,366 @@
+// Copyright 2025 Tencent
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include "testutil.h"
+
+// Run the GatherElements layer and return the output blob.
+static int run_gatherelements(const ncnn::Mat& data, const ncnn::Mat& indices, int axis, ncnn::Mat& out,
+                              int num_threads = 1)
+{
+    // data and indices are the two bottom blobs, axis is layer param 0, and the
+    // single top blob is copied into out. Returns 0 on success, -1 when the
+    // layer cannot be created, otherwise the first non-zero status reported by
+    // load_param / load_model / create_pipeline / forward.
+    ncnn::ParamDict pd;
+    pd.set(0, axis);
+
+    ncnn::Option opt;
+    opt.num_threads = num_threads;
+    opt.use_vulkan_compute = false;
+    opt.use_packing_layout = false;
+
+    ncnn::Layer* op = ncnn::create_layer_cpu("GatherElements");
+    if (!op)
+        return -1;
+
+    // No weights are supplied; the model-bin wraps an empty array.
+    std::vector<ncnn::Mat> weights(0);
+    ncnn::ModelBinFromMatArray mb(weights.data());
+
+    // Never call forward() on a layer whose setup reported an error.
+    int ret = op->load_param(pd);
+    if (ret == 0)
+        ret = op->load_model(mb);
+    if (ret == 0)
+        ret = op->create_pipeline(opt);
+    if (ret != 0)
+    {
+        delete op;
+        return ret;
+    }
+
+    std::vector<ncnn::Mat> bottom_blobs(2);
+    bottom_blobs[0] = data;
+    bottom_blobs[1] = indices;
+
+    std::vector<ncnn::Mat> top_blobs(1);
+    ret = op->forward(bottom_blobs, top_blobs, opt);
+
+    op->destroy_pipeline(opt);
+    delete op;
+
+    if (ret != 0)
+        return ret;
+
+    out = top_blobs[0];
+    return 0;
+}
+
+// Read index at flat element offset, supporting int32 and int64.
+// An elemsize of 8 selects the int64 view; the value is narrowed to int.
+static int read_flat_idx(const ncnn::Mat& m, int flat)
+{
+    if (m.elemsize == 8)
+        return (int)((const int64_t*)(const void*)m)[flat];
+    return ((const int*)(const void*)m)[flat];
+}
+
+// Normalize a raw index for an axis of axis_size entries: a negative value is
+// first offset by axis_size, then anything still outside [0, axis_size) is
+// clamped to the nearest valid position.
+static int wrap_and_clamp_index(int gi, int axis_size)
+{
+    if (gi < 0) gi += axis_size;
+    if (gi < 0) gi = 0;
+    if (gi >= axis_size) gi = axis_size - 1;
+    return gi;
+}
+
+// Reference GatherElements: PyTorch-style axis ordering.
+// Index has same rank as data. For each position (z,y,x) in index:
+//   axis=0: out[z,y,x] = data[idx[z,y,x], y, x]
+//   axis=1: out[z,y,x] = data[z, idx[z,y,x], x]
+//   axis=2: out[z,y,x] = data[z, y, idx[z,y,x]]
+// The output has the shape of indices and holds 4-byte elements.
+static ncnn::Mat ref_gatherelements(const ncnn::Mat& data, const ncnn::Mat& indices, int axis)
+{
+    const int dims = data.dims;
+    const int positive_axis = axis < 0 ? axis + dims : axis;
+
+    // Extents listed outermost-first so that shape[positive_axis] is the
+    // length of the gathered axis.
+    int shape[3] = {1, 1, 1};
+    if (dims == 1)
+        shape[0] = data.w;
+    else if (dims == 2)
+    {
+        shape[0] = data.h;
+        shape[1] = data.w;
+    }
+    else
+    {
+        shape[0] = data.c;
+        shape[1] = data.h;
+        shape[2] = data.w;
+    }
+    const int axis_size = shape[positive_axis];
+
+    ncnn::Mat out;
+    if (indices.dims == 1)
+        out.create(indices.w, (size_t)4u);
+    else if (indices.dims == 2)
+        out.create(indices.w, indices.h, (size_t)4u);
+    else
+        out.create(indices.w, indices.h, indices.c, (size_t)4u);
+
+    const float* dp = data;
+    float* op_ptr = out;
+
+    if (dims == 1)
+    {
+        for (int x = 0; x < indices.w; x++)
+        {
+            const int gi = wrap_and_clamp_index(read_flat_idx(indices, x), axis_size);
+            op_ptr[x] = dp[gi];
+        }
+    }
+    else if (dims == 2)
+    {
+        const int dw = data.w;
+        const int idxw = indices.w;
+        for (int y = 0; y < indices.h; y++)
+            for (int x = 0; x < idxw; x++)
+            {
+                const int gi = wrap_and_clamp_index(read_flat_idx(indices, y * idxw + x), axis_size);
+
+                // axis 0 replaces the row coordinate, axis 1 the column coordinate.
+                int flat_in = (positive_axis == 0) ? gi * dw + x : y * dw + gi;
+                op_ptr[y * out.w + x] = dp[flat_in];
+            }
+    }
+    else // dims == 3
+    {
+        const int dw = data.w;
+        // Channels are addressed through cstep rather than w * h.
+        const size_t d_cstep = data.cstep;
+        const size_t i_cstep = indices.cstep;
+        const size_t o_cstep = out.cstep;
+        const int idxw = indices.w;
+
+        for (int z = 0; z < indices.c; z++)
+            for (int y = 0; y < indices.h; y++)
+                for (int x = 0; x < idxw; x++)
+                {
+                    const int raw = read_flat_idx(indices, (int)(z * i_cstep) + y * idxw + x);
+                    const int gi = wrap_and_clamp_index(raw, axis_size);
+
+                    int flat_in;
+                    if (positive_axis == 0)
+                        flat_in = (int)(gi * d_cstep) + y * dw + x;
+                    else if (positive_axis == 1)
+                        flat_in = (int)(z * d_cstep) + gi * dw + x;
+                    else
+                        flat_in = (int)(z * d_cstep) + y * dw + gi;
+
+                    op_ptr[(int)(z * o_cstep) + y * out.w + x] = dp[flat_in];
+                }
+    }
+
+    return out;
+}
+
+// Build an int32 index Mat with values in [0, axis_size).
+// Uses a deterministic pattern: idx[i] = (i * 3 + 1) % axis_size.
+// The Mat is 3D when c > 1, 2D when h > 1, otherwise 1D.
+static ncnn::Mat make_indices(int w, int h, int c, int axis_size)
+{
+    ncnn::Mat m;
+    if (c > 1)
+        m.create(w, h, c, (size_t)4u);
+    else if (h > 1)
+        m.create(w, h, (size_t)4u);
+    else
+        m.create(w, (size_t)4u);
+
+    int* p = (int*)(void*)m;
+    int total = (int)m.total();
+    for (int i = 0; i < total; i++)
+        p[i] = (i * 3 + 1) % axis_size;
+    return m;
+}
+
+// Build an int64 index Mat with the same pattern.
+static ncnn::Mat make_indices_i64(int w, int h, int c, int axis_size)
+{
+    // 8-byte elements; the Mat is 3D when c > 1, 2D when h > 1, otherwise 1D.
+    ncnn::Mat m;
+    if (c > 1)
+        m.create(w, h, c, (size_t)8u);
+    else if (h > 1)
+        m.create(w, h, (size_t)8u);
+    else
+        m.create(w, (size_t)8u);
+
+    int64_t* p = (int64_t*)(void*)m;
+    int total = (int)m.total();
+    for (int i = 0; i < total; i++)
+        p[i] = (i * 3 + 1) % axis_size;
+    return m;
+}
+
+// Compare two float blobs element by element with exact equality.
+// a is reported as "got" and b as "expected". Returns 0 when the shapes and
+// all values match, -1 (after printing the first difference) otherwise.
+static int check_equal(const ncnn::Mat& a, const ncnn::Mat& b, const char* name)
+{
+    if (a.dims != b.dims || a.w != b.w || a.h != b.h || a.c != b.c)
+    {
+        fprintf(stderr, "%s: shape mismatch got(%d %d %d dims=%d) expected(%d %d %d dims=%d)\n",
+                name, a.w, a.h, a.c, a.dims, b.w, b.h, b.c, b.dims);
+        return -1;
+    }
+    // Use explicit loops to avoid comparing uninitialized cstep padding bytes
+    const float* ad = (const float*)a.data;
+    const float* bd = (const float*)b.data;
+    for (int z = 0; z < a.c; z++)
+        for (int y = 0; y < a.h; y++)
+            for (int x = 0; x < a.w; x++)
+            {
+                float av = ad[z * a.cstep + y * a.w + x];
+                float bv = bd[z * b.cstep + y * b.w + x];
+                if (av != bv)
+                {
+                    fprintf(stderr, "%s: value mismatch at z=%d y=%d x=%d: got %f expected %f\n",
+                            name, z, y, x, av, bv);
+                    return -1;
+                }
+            }
+    return 0;
+}
+
+// Run the layer single-threaded and compare it with ref_gatherelements.
+static int test_gatherelements(const ncnn::Mat& data, const ncnn::Mat& indices, int axis, const char* name)
+{
+    ncnn::Mat expected = ref_gatherelements(data, indices, axis);
+    ncnn::Mat got;
+    int ret = run_gatherelements(data, indices, axis, got);
+    if (ret != 0)
+    {
+        fprintf(stderr, "%s: forward failed\n", name);
+        return -1;
+    }
+    return check_equal(got, expected, name);
+}
+
+static int test_gatherelements_1d()
+{
+    ncnn::Mat data = RandomMat(10);
+    ncnn::Mat idx = make_indices(5, 1, 1, 10);
+    return test_gatherelements(data, idx, 0, "gatherelements_1d_axis0");
+}
+
+static int test_gatherelements_2d()
+{
+    ncnn::Mat data = RandomMat(8, 5); // w=8 h=5
+
+    // axis=0 (h, size=5), index shape [3,8]
+    ncnn::Mat idx0 = make_indices(8, 3, 1, 5);
+    if (test_gatherelements(data, idx0, 0, "gatherelements_2d_axis0") != 0) return -1;
+
+    // axis=1 (w, size=8), index shape [5,4]
+    ncnn::Mat idx1 = make_indices(4, 5, 1, 8);
+    if (test_gatherelements(data, idx1, 1, "gatherelements_2d_axis1") != 0) return -1;
+
+    return 0;
+}
+
+static int test_gatherelements_3d()
+{
+    ncnn::Mat data = RandomMat(8, 6, 4); // w=8 h=6 c=4
+
+    // axis=0 (c, size=4), index shape [2,6,8]
+    ncnn::Mat idx0 = make_indices(8, 6, 2, 4);
+    if (test_gatherelements(data, idx0, 0, "gatherelements_3d_axis0") != 0) return -1;
+
+    // axis=1 (h, size=6), index shape [4,3,8]
+    ncnn::Mat idx1 = make_indices(8, 3, 4, 6);
+    if (test_gatherelements(data, idx1, 1, "gatherelements_3d_axis1") != 0) return -1;
+
+    // axis=2 (w, size=8), index shape [4,6,5]
+    ncnn::Mat idx2 = make_indices(5, 6, 4, 8);
+    if (test_gatherelements(data, idx2, 2, "gatherelements_3d_axis2") != 0) return -1;
+
+    return 0;
+}
+
+static int test_gatherelements_negative_axis()
+{
+    ncnn::Mat data = RandomMat(8, 6, 4); // w=8 h=6 c=4
+
+    // axis=-1 == axis=2 (w, size=8)
+    ncnn::Mat idx = make_indices(5, 6, 4, 8);
+    if (test_gatherelements(data, idx, -1, "gatherelements_3d_axis-1") != 0) return -1;
+
+    // axis=-3 == axis=0 (c, size=4)
+    ncnn::Mat idx0 = make_indices(8, 6, 2, 4);
+    if (test_gatherelements(data, idx0, -3, "gatherelements_3d_axis-3") != 0) return -1;
+
+    return 0;
+}
+
+static int test_gatherelements_clamp()
+{
+    // 1D: out-of-range indices must clamp, not crash.
+    ncnn::Mat data = RandomMat(6);
+    ncnn::Mat idx;
+    idx.create(4, (size_t)4u);
+    int* p = (int*)(void*)idx;
+    p[0] = -10; // clamps to 0
+    p[1] = 0;
+    p[2] = 5;
+    p[3] = 100; // clamps to 5
+
+    if (test_gatherelements(data, idx, 0, "gatherelements_clamp_1d") != 0) return -1;
+
+    // 2D axis=0: negative row indices. With h=4 the values span -1..3, so -1
+    // wraps to row 3 and 3 is the last valid row; nothing here needs clamping.
+    {
+        ncnn::Mat data2d = RandomMat(5, 4);
+        ncnn::Mat idx2d;
+        idx2d.create(5, 4, (size_t)4u); // same shape as data (GatherElements requirement)
+        int* q = (int*)(void*)idx2d;
+        for (int i = 0; i < 20; i++) q[i] = (i % 5) - 1; // values: -1, 0, 1, 2, 3
+        if (test_gatherelements(data2d, idx2d, 0, "gatherelements_clamp_2d_axis0") != 0) return -1;
+    }
+
+    // 3D axis=1: out-of-range height indices (h=4: -2/-1 wrap, 4 clamps to 3)
+    {
+        ncnn::Mat data3d = RandomMat(6, 4, 3);
+        ncnn::Mat idx3d;
+        idx3d.create(6, 4, 3, (size_t)4u);
+        int* q = (int*)(void*)idx3d;
+        for (int i = 0; i < (int)idx3d.total(); i++) q[i] = (i % 7) - 2;
+        if (test_gatherelements(data3d, idx3d, 1, "gatherelements_clamp_3d_axis1") != 0) return -1;
+    }
+
+    return 0;
+}
+
+// Multi-threaded: result must match single-threaded (catches OMP data races).
+static int test_gatherelements_multithread()
+{
+    ncnn::Mat data = RandomMat(16, 12, 8);
+    ncnn::Mat idx = make_indices(16, 12, 8, 12); // axis=1 (h=12)
+
+    ncnn::Mat out_single, out_multi;
+    if (run_gatherelements(data, idx, 1, out_single, 1) != 0
+            || run_gatherelements(data, idx, 1, out_multi, 4) != 0)
+    {
+        fprintf(stderr, "gatherelements_multithread: forward failed\n");
+        return -1;
+    }
+    return check_equal(out_single, out_multi, "gatherelements_multithread");
+}
+
+static int test_gatherelements_int64_indices()
+{
+    // Verify the int64 index path (elemsize==8) works identically to int32.
+    ncnn::Mat data = RandomMat(8, 5); // w=8 h=5
+
+    ncnn::Mat idx0_i64 = make_indices_i64(8, 3, 1, 5);
+    if (test_gatherelements(data, idx0_i64, 0, "gatherelements_i64_2d_axis0") != 0) return -1;
+
+    ncnn::Mat idx1_i64 = make_indices_i64(4, 5, 1, 8);
+    if (test_gatherelements(data, idx1_i64, 1, "gatherelements_i64_2d_axis1") != 0) return -1;
+
+    ncnn::Mat data3d = RandomMat(8, 6, 4);
+    ncnn::Mat idx3d_i64 = make_indices_i64(8, 3, 4, 6);
+    if (test_gatherelements(data3d, idx3d_i64, 1, "gatherelements_i64_3d_axis1") != 0) return -1;
+
+    return 0;
+}
+
+int main()
+{
+    SRAND(7767517);
+
+    return 0
+           || test_gatherelements_1d()
+           || test_gatherelements_2d()
+           || test_gatherelements_3d()
+           || test_gatherelements_negative_axis()
+           || test_gatherelements_clamp()
+           || test_gatherelements_int64_indices()
+           || test_gatherelements_multithread();
+}
diff --git a/tests/test_mod.cpp b/tests/test_mod.cpp
new file mode 100644
index 000000000000..5eb7c8efd9e8
--- /dev/null
+++ b/tests/test_mod.cpp
@@ -0,0 +1,232 @@
+// Copyright 2025 Tencent
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include "testutil.h"
+
+// NOTE(review): the header name was lost in this copy of the patch; <math.h>
+// is assumed because the file calls fmodf - confirm against the original change.
+#include <math.h>
+
+// Run the Mod layer on bottom blobs (a, b) with fmode as layer param 0 and
+// copy the single top blob into out. Returns 0 on success, -1 when the layer
+// cannot be created, otherwise the first non-zero status reported by
+// load_param / load_model / create_pipeline / forward.
+static int run_mod(const ncnn::Mat& a, const ncnn::Mat& b, int fmode, ncnn::Mat& out)
+{
+    ncnn::ParamDict pd;
+    pd.set(0, fmode);
+
+    ncnn::Option opt;
+    opt.num_threads = 1;
+    opt.use_vulkan_compute = false;
+    opt.use_packing_layout = false;
+
+    ncnn::Layer* op = ncnn::create_layer_cpu("Mod");
+    if (!op)
+        return -1;
+
+    // No weights are supplied; the model-bin wraps an empty array.
+    std::vector<ncnn::Mat> weights(0);
+    ncnn::ModelBinFromMatArray mb(weights.data());
+
+    // Never call forward() on a layer whose setup reported an error.
+    int ret = op->load_param(pd);
+    if (ret == 0)
+        ret = op->load_model(mb);
+    if (ret == 0)
+        ret = op->create_pipeline(opt);
+    if (ret != 0)
+    {
+        delete op;
+        return ret;
+    }
+
+    std::vector<ncnn::Mat> bottom_blobs(2);
+    bottom_blobs[0] = a;
+    bottom_blobs[1] = b;
+
+    std::vector<ncnn::Mat> top_blobs(1);
+    ret = op->forward(bottom_blobs, top_blobs, opt);
+
+    op->destroy_pipeline(opt);
+    delete op;
+
+    if (ret != 0)
+        return ret;
+
+    out = top_blobs[0];
+    return 0;
+}
+
+// Compare layer output against fmodf reference with exact equality.
+// The impl uses ::fmodf (float-precision), so results must be bit-identical.
+static int test_mod(int w, int h, int c, int fmode, const char* name)
+{
+    // Random dividend and divisor of identical shape.
+    ncnn::Mat a = RandomMat(w, h, c);
+    ncnn::Mat b = RandomMat(w, h, c);
+
+    // Ensure b is non-zero: any exact zero in the divisor becomes 1.
+    float* const divisor = (float*)b;
+    for (int z = 0; z < c; z++)
+    {
+        for (int y = 0; y < h; y++)
+        {
+            for (int x = 0; x < w; x++)
+            {
+                float& d = divisor[z * (int)b.cstep + y * w + x];
+                if (d == 0.0f)
+                    d = 1.0f;
+            }
+        }
+    }
+
+    ncnn::Mat out;
+    const int status = run_mod(a, b, fmode, out);
+    if (status != 0)
+    {
+        fprintf(stderr, "%s: forward failed\n", name);
+        return -1;
+    }
+
+    const bool shape_ok = out.w == w && out.h == h && out.c == c;
+    if (!shape_ok)
+    {
+        fprintf(stderr, "%s: shape mismatch\n", name);
+        return -1;
+    }
+
+    const float* const lhs = (const float*)a;
+    const float* const rhs = (const float*)b;
+    const float* const res = (const float*)out;
+
+    for (int z = 0; z < c; z++)
+    {
+        for (int y = 0; y < h; y++)
+        {
+            for (int x = 0; x < w; x++)
+            {
+                // Each blob is addressed through its own channel step.
+                const float val_a = lhs[z * (int)a.cstep + y * w + x];
+                const float val_b = rhs[z * (int)b.cstep + y * w + x];
+                const float val_out = res[z * (int)out.cstep + y * w + x];
+
+                // Start from the C remainder; for fmode == 0 a non-zero
+                // remainder whose sign differs from the divisor's is shifted
+                // by the divisor so that it takes the divisor's sign.
+                float expected = fmodf(val_a, val_b);
+                if (fmode == 0 && expected != 0.0f && (val_b < 0.0f) != (expected < 0.0f))
+                    expected += val_b;
+
+                if (val_out == expected)
+                    continue;
+
+                fprintf(stderr, "%s: value mismatch at z=%d y=%d x=%d: got %f expected %f\n",
+                        name, z, y, x, val_out, expected);
+                return -1;
+            }
+        }
+    }
+    return 0;
+}
+
+// Zero divisor: b=0 must return 0, not crash.
+static int test_mod_zero_divisor()
+{
+    // Five dividends of mixed sign against an all-zero divisor, checked for
+    // both fmode values (0 and 1).
+    ncnn::Mat a(5, (size_t)4u);
+    ncnn::Mat b(5, (size_t)4u);
+    float* ap = a;
+    float* bp = b;
+    ap[0] = 7.f;
+    ap[1] = -3.f;
+    ap[2] = 0.f;
+    ap[3] = 100.f;
+    ap[4] = -50.f;
+    for (int i = 0; i < 5; i++) bp[i] = 0.0f;
+
+    ncnn::Mat out;
+    for (int fmode = 0; fmode <= 1; fmode++)
+    {
+        if (run_mod(a, b, fmode, out) != 0)
+        {
+            fprintf(stderr, "test_mod_zero_divisor fmode=%d: forward failed\n", fmode);
+            return -1;
+        }
+        // Verify the width before reading five elements so that a wrong-sized
+        // or empty output is reported instead of being read out of bounds.
+        if (out.w != 5)
+        {
+            fprintf(stderr, "test_mod_zero_divisor fmode=%d: expected 5 outputs, got %d\n", fmode, out.w);
+            return -1;
+        }
+        const float* op = out;
+        for (int i = 0; i < 5; i++)
+        {
+            if (op[i] != 0.0f)
+            {
+                fprintf(stderr, "test_mod_zero_divisor fmode=%d: expected 0 at %d, got %f\n",
+                        fmode, i, op[i]);
+                return -1;
+            }
+        }
+    }
+    return 0;
+}
+
+// Python-style mod with known negative inputs/divisors.
+static int test_mod_negative_values()
+{
+    ncnn::Mat a(6, (size_t)4u);
+    ncnn::Mat b(6, (size_t)4u);
+    float avals[6] = {-10, -8, -6, -4, -2, 0};
+    float bvals[6] = {3, 3, 3, 3, 3, 3};
+    float* ap = a;
+    float* bp = b;
+    for (int i = 0; i < 6; i++)
+    {
+        ap[i] = avals[i];
+        bp[i] = bvals[i];
+    }
+
+    ncnn::Mat out;
+    if (run_mod(a, b, 0, out) != 0)
+    {
+        fprintf(stderr, "test_mod_negative_values: forward failed\n");
+        return -1;
+    }
+    // Verify the width before reading six elements (see the zero-divisor case
+    // for the same guard).
+    if (out.w != 6)
+    {
+        fprintf(stderr, "test_mod_negative_values: expected 6 outputs, got %d\n", out.w);
+        return -1;
+    }
+    // Python mod: -10%3=2, -8%3=1, -6%3=0, -4%3=2, -2%3=1, 0%3=0
+    float expected[6] = {2, 1, 0, 2, 1, 0};
+    const float* op = out;
+    for (int i = 0; i < 6; i++)
+    {
+        if (op[i] != expected[i])
+        {
+            fprintf(stderr, "test_mod_negative_values: mismatch at %d: got %f expected %f\n",
+                    i, op[i], expected[i]);
+            return -1;
+        }
+    }
+    return 0;
+}
+
+// C-style fmod with negative b — sign of result follows the dividend, not divisor.
+static int test_mod_fmod1_negative_b()
+{
+    ncnn::Mat a(4, (size_t)4u);
+    ncnn::Mat b(4, (size_t)4u);
+    float* ap = a;
+    float* bp = b;
+    ap[0] = 7.f;
+    bp[0] = -3.f; // fmod(7, -3) = 1  (sign of dividend +7)
+    ap[1] = -7.f;
+    bp[1] = 3.f; // fmod(-7, 3) = -1 (sign of dividend -7)
+    ap[2] = -7.f;
+    bp[2] = -3.f; // fmod(-7, -3) = -1
+    ap[3] = 6.f;
+    bp[3] = -2.f; // fmod(6, -2) = 0
+
+    ncnn::Mat out;
+    if (run_mod(a, b, 1, out) != 0)
+    {
+        fprintf(stderr, "test_mod_fmod1_negative_b: forward failed\n");
+        return -1;
+    }
+    // Verify the width before reading four elements so that a wrong-sized or
+    // empty output is reported instead of being read out of bounds.
+    if (out.w != 4)
+    {
+        fprintf(stderr, "test_mod_fmod1_negative_b: expected 4 outputs, got %d\n", out.w);
+        return -1;
+    }
+    const float* op = out;
+    // The reference is fmodf itself, so the comparison is exact.
+    float expected[4] = {
+        fmodf(7.f, -3.f),
+        fmodf(-7.f, 3.f),
+        fmodf(-7.f, -3.f),
+        fmodf(6.f, -2.f)
+    };
+    for (int i = 0; i < 4; i++)
+    {
+        if (op[i] != expected[i])
+        {
+            fprintf(stderr, "test_mod_fmod1_negative_b: mismatch at %d: got %f expected %f\n",
+                    i, op[i], expected[i]);
+            return -1;
+        }
+    }
+    return 0;
+}
+
+int main()
+{
+    SRAND(7767517);
+
+    return 0
+           || test_mod(10, 1, 1, 0, "mod_1d_python")
+           || test_mod(10, 1, 1, 1, "mod_1d_c")
+           || test_mod(8, 6, 1, 0, "mod_2d_python")
+           || test_mod(8, 6, 1, 1, "mod_2d_c")
+           || test_mod(4, 6, 8, 0, "mod_3d_python")
+           || test_mod(4, 6, 8, 1, "mod_3d_c")
+           || test_mod_zero_divisor()
+           || test_mod_negative_values()
+           || test_mod_fmod1_negative_b();
+}
diff --git a/tests/test_topk.cpp b/tests/test_topk.cpp
new file mode 100644
index 000000000000..04b9a723bd2b
--- /dev/null
+++ b/tests/test_topk.cpp
@@ -0,0 +1,592 @@
+// Copyright 2025 Tencent
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include "testutil.h"
+
+#if NCNN_SIMPLESTL
+static const float TEST_INF = 1.f / 0.f;
+static const float TEST_NAN = 0.f / 0.f;
+#define INFINITY TEST_INF
+#define NAN      TEST_NAN
+#else
+// NOTE(review): both header names were lost in this copy of the patch;
+// <algorithm> (std::sort) and <math.h> (INFINITY / NAN) are assumed from what
+// the file uses - confirm against the original change.
+#include <algorithm>
+#include <math.h>
+#endif
+
+// Unified runner: want_indices=false → top_blobs(1), else top_blobs(2).
+static int run_topk(const ncnn::Mat& a, int axis, int k, int largest, int sorted,
+                    bool want_indices, ncnn::Mat& values, ncnn::Mat& indices)
+{
+    // Layer params: 0=axis, 1=largest, 2=sorted, 3=k. values receives top blob
+    // 0; indices receives top blob 1 and is only written when want_indices is
+    // true. Returns 0 on success, -1 when the layer cannot be created,
+    // otherwise the first non-zero status reported by load_param / load_model /
+    // create_pipeline / forward.
+    ncnn::ParamDict pd;
+    pd.set(0, axis);
+    pd.set(1, largest);
+    pd.set(2, sorted);
+    pd.set(3, k);
+
+    ncnn::Option opt;
+    opt.num_threads = 1;
+    opt.use_vulkan_compute = false;
+    opt.use_packing_layout = false;
+
+    ncnn::Layer* op = ncnn::create_layer_cpu("TopK");
+    if (!op)
+        return -1;
+
+    // No weights are supplied; the model-bin wraps an empty array.
+    std::vector<ncnn::Mat> weights(0);
+    ncnn::ModelBinFromMatArray mb(weights.data());
+
+    // Never call forward() on a layer whose setup reported an error.
+    int ret = op->load_param(pd);
+    if (ret == 0)
+        ret = op->load_model(mb);
+    if (ret == 0)
+        ret = op->create_pipeline(opt);
+    if (ret != 0)
+    {
+        delete op;
+        return ret;
+    }
+
+    std::vector<ncnn::Mat> bottom_blobs(1);
+    bottom_blobs[0] = a;
+    std::vector<ncnn::Mat> top_blobs(want_indices ? 2 : 1);
+    ret = op->forward(bottom_blobs, top_blobs, opt);
+
+    op->destroy_pipeline(opt);
+    delete op;
+
+    if (ret != 0)
+        return ret;
+
+    values = top_blobs[0];
+    if (want_indices)
+        indices = top_blobs[1];
+    return 0;
+}
+
+// Hand one TopK configuration to the shared test_layer harness (two top blobs,
+// tolerance 0.01, automatic input casting disabled) and log the configuration
+// when it fails.
+static int test_topk(const ncnn::Mat& a, int axis, int k, int largest, int sorted)
+{
+    ncnn::ParamDict pd;
+    pd.set(0, axis);
+    pd.set(1, largest);
+    pd.set(2, sorted);
+    pd.set(3, k);
+
+    std::vector<ncnn::Mat> weights(0);
+    std::vector<ncnn::Mat> a0(1);
+    a0[0] = a;
+
+    int ret = test_layer("TopK", pd, weights, a0, 2, 0.01f, TEST_LAYER_DISABLE_AUTO_INPUT_CASTING);
+    if (ret != 0)
+    {
+        fprintf(stderr, "test_topk failed a.dims=%d a=(%d %d %d %d) axis=%d k=%d largest=%d sorted=%d\n",
+                a.dims, a.w, a.h, a.d, a.c, axis, k, largest, sorted);
+    }
+    return ret;
+}
+
+static int test_topk_0()
+{
+    ncnn::Mat a = RandomMat(13);
+
+    return 0
+           || test_topk(a, 0, 1, 1, 1)
+           || test_topk(a, 0, 5, 1, 1)
+           || test_topk(a, 0, 1, 0, 0)
+           || test_topk(a, -1, 7, 0, 1)
+           || test_topk(a, 0, 4, 1, 0)
+           || test_topk(a, 0, 9, 1, 1);
+}
+
+static int test_topk_1()
+{
+    ncnn::Mat a = RandomMat(12, 17);
+
+    return 0
+           || test_topk(a, 0, 1, 1, 1)
+           || test_topk(a, 0, 5, 1, 1)
+           || test_topk(a, 1, 3, 1, 1)
+           || test_topk(a, -1, 8, 0, 1)
+           || test_topk(a, 1, 6, 0, 0)
+           || test_topk(a, -2, 7, 1, 1);
+}
+
+static int test_topk_2()
+{
+    ncnn::Mat a = RandomMat(8, 9, 11);
+
+    return 0
+           || test_topk(a, 0, 3, 1, 1)
+           || test_topk(a, 1, 4, 1, 1)
+           || test_topk(a, 2, 2, 0, 1)
+           || test_topk(a, 2, 5, 1, 0)
+           || test_topk(a, -1, 6, 1, 1)
+           || test_topk(a, -2, 5, 0, 1)
+           || test_topk(a, -3, 7, 1, 1);
+}
+
+static int test_topk_3()
+{
+    ncnn::Mat a = RandomMat(5, 7, 9, 10);
+
+    return 0
+           || test_topk(a, 0, 2, 1, 1)
+           || test_topk(a, 1, 3, 0, 1)
+           || test_topk(a, 2, 4, 1, 1)
+           || test_topk(a, 3, 4, 0, 0)
+           || test_topk(a, 3, 5, 1, 1)
+           || test_topk(a, -1, 6, 0, 1)
+           || test_topk(a, -2, 3, 1, 1)
+           || test_topk(a, -3, 4, 0, 1)
+           || test_topk(a, -4, 2, 1, 1);
+}
+
+// +inf must rank as the largest value and -inf as the smallest.
+static int test_topk_inf_order()
+{
+    ncnn::Mat a(6);
+    float* ptr = a;
+    ptr[0] = 1.f;
+    ptr[1] = INFINITY;
+    ptr[2] = -2.f;
+    ptr[3] = -INFINITY;
+    ptr[4] = 0.5f;
+    ptr[5] = 3.f;
+
+    ncnn::Mat values, indices;
+
+    if (run_topk(a, 0, 2, 1, 1, true, values, indices) != 0)
+    {
+        fprintf(stderr, "test_topk_inf_order largest failed\n");
+        return -1;
+    }
+    const float* vptr = values;
+    const int* iptr = (const int*)(const void*)indices;
+    // The width checks come first so that no element is read from a
+    // wrong-sized blob.
+    if (values.w != 2 || indices.w != 2 || vptr[0] != INFINITY || vptr[1] != 3.f || iptr[0] != 1 || iptr[1] != 5)
+    {
+        fprintf(stderr, "test_topk_inf_order largest mismatch\n");
+        return -1;
+    }
+
+    if (run_topk(a, 0, 2, 0, 1, true, values, indices) != 0)
+    {
+        fprintf(stderr, "test_topk_inf_order smallest failed\n");
+        return -1;
+    }
+    vptr = values;
+    iptr = (const int*)(const void*)indices;
+    if (values.w != 2 || indices.w != 2 || vptr[0] != -INFINITY || vptr[1] != -2.f || iptr[0] != 3 || iptr[1] != 2)
+    {
+        fprintf(stderr, "test_topk_inf_order smallest mismatch\n");
+        return -1;
+    }
+
+    return 0;
+}
+
+static int test_topk_nan_robust()
+{
+    // NaN mid-array: [1, NaN, 2, -1], k=2, largest → {2@2, 1@0}
+    ncnn::Mat a(4);
+    float* ptr = a;
+    ptr[0] = 1.f;
+    ptr[1] = NAN;
+    ptr[2] = 2.f;
+    ptr[3] = -1.f;
+
+    ncnn::Mat values, indices;
+
+    if (run_topk(a, 0, 2, 1, 1, true, values, indices) != 0)
+    {
+        fprintf(stderr, "test_topk_nan_robust sorted failed\n");
+        return -1;
+    }
+    const float* vptr = values;
+    const int* iptr = (const int*)(const void*)indices;
+    // indices.w is checked alongside values.w so that iptr is never read past
+    // the end of a wrong-sized index blob.
+    if (values.w != 2 || indices.w != 2 || vptr[0] != 2.f || vptr[1] != 1.f || iptr[0] != 2 || iptr[1] != 0)
+    {
+        fprintf(stderr, "test_topk_nan_robust sorted largest mismatch\n");
+        return -1;
+    }
+
+    if (run_topk(a, 0, 2, 0, 1, true, values, indices) != 0)
+    {
+        fprintf(stderr, "test_topk_nan_robust sorted smallest failed\n");
+        return -1;
+    }
+    vptr = values;
+    iptr = (const int*)(const void*)indices;
+    if (values.w != 2 || indices.w != 2 || vptr[0] != -1.f || vptr[1] != 1.f || iptr[0] != 3 || iptr[1] != 0)
+    {
+        fprintf(stderr, "test_topk_nan_robust sorted smallest mismatch\n");
+        return -1;
+    }
+
+    // Unsorted output: only require two indices that point inside the input.
+    if (run_topk(a, 0, 2, 1, 0, true, values, indices) != 0)
+    {
+        fprintf(stderr, "test_topk_nan_robust unsorted failed\n");
+        return -1;
+    }
+    iptr = (const int*)(const void*)indices;
+    if (indices.w != 2 || iptr[0] < 0 || iptr[0] >= 4 || iptr[1] < 0 || iptr[1] >= 4)
+    {
+        fprintf(stderr, "test_topk_nan_robust unsorted invalid indices\n");
+        return -1;
+    }
+
+    return 0;
+}
+
+// NaN at index 0 — exercises `has_nan = topk_isnan(best_value)` at the top of
+// the k=1 scalar fast path; without this, the fast loop is entered with a NaN
+// as the running best and comparisons are silently wrong.
+static int test_topk_nan_first_element()
+{
+    ncnn::Mat a(5);
+    float* ptr = a;
+    ptr[0] = NAN;
+    ptr[1] = 3.f;
+    ptr[2] = 1.f;
+    ptr[3] = 5.f;
+    ptr[4] = 2.f;
+
+    ncnn::Mat values, indices;
+
+    // k=1 largest: best is 5@3
+    if (run_topk(a, 0, 1, 1, 1, true, values, indices) != 0)
+    {
+        fprintf(stderr, "test_topk_nan_first_element k1 failed\n");
+        return -1;
+    }
+    // Check both widths before touching any element: the mismatch message
+    // below prints vp[0] and ip[0], which must not be read from a wrong-sized
+    // or empty blob.
+    if (values.w != 1 || indices.w != 1)
+    {
+        fprintf(stderr, "test_topk_nan_first_element k1 shape mismatch values.w=%d indices.w=%d\n",
+                values.w, indices.w);
+        return -1;
+    }
+    const float* vp = values;
+    const int* ip = (const int*)(const void*)indices;
+    if (vp[0] != 5.f || ip[0] != 3)
+    {
+        fprintf(stderr, "test_topk_nan_first_element k1 mismatch v=%f i=%d\n", vp[0], ip[0]);
+        return -1;
+    }
+
+    // k=2 smallest sorted: {1@2, 2@4}
+    if (run_topk(a, 0, 2, 0, 1, true, values, indices) != 0)
+    {
+        fprintf(stderr, "test_topk_nan_first_element k2 failed\n");
+        return -1;
+    }
+    vp = values;
+    ip = (const int*)(const void*)indices;
+    // Width checks short-circuit ahead of the element reads.
+    if (values.w != 2 || indices.w != 2 || vp[0] != 1.f || vp[1] != 2.f || ip[0] != 2 || ip[1] != 4)
+    {
+        fprintf(stderr, "test_topk_nan_first_element k2 mismatch\n");
+        return -1;
+    }
+
+    return 0;
+}
+
+// Multiple NaN values — exercises NaN eviction from the k-buffer in the k≤4 path.
+static int test_topk_multiple_nans()
+{
+    // NaN at every even position, finite values 2, 5, 1 at positions 1, 3, 5.
+    ncnn::Mat a(7);
+    float* ptr = a;
+    ptr[0] = NAN;
+    ptr[1] = 2.f;
+    ptr[2] = NAN;
+    ptr[3] = 5.f;
+    ptr[4] = NAN;
+    ptr[5] = 1.f;
+    ptr[6] = NAN;
+
+    ncnn::Mat values, indices;
+
+    // k=2, largest, sorted: {5@3, 2@1}
+    if (run_topk(a, 0, 2, 1, 1, true, values, indices) != 0)
+    {
+        fprintf(stderr, "test_topk_multiple_nans failed\n");
+        return -1;
+    }
+    // Check both widths before touching any element: the mismatch message
+    // below prints two values and two indices, which must not be read from a
+    // wrong-sized or empty blob.
+    if (values.w != 2 || indices.w != 2)
+    {
+        fprintf(stderr, "test_topk_multiple_nans shape mismatch values.w=%d indices.w=%d\n",
+                values.w, indices.w);
+        return -1;
+    }
+    const float* vp = values;
+    const int* ip = (const int*)(const void*)indices;
+    if (vp[0] != 5.f || vp[1] != 2.f || ip[0] != 3 || ip[1] != 1)
+    {
+        fprintf(stderr, "test_topk_multiple_nans mismatch v=[%f,%f] i=[%d,%d]\n",
+                vp[0], vp[1], ip[0], ip[1]);
+        return -1;
+    }
+
+    // k=3, smallest, sorted: {1@5, 2@1, 5@3}
+    if (run_topk(a, 0, 3, 0, 1, true, values, indices) != 0)
+    {
+        fprintf(stderr, "test_topk_multiple_nans k3 failed\n");
+        return -1;
+    }
+    vp = values;
+    ip = (const int*)(const void*)indices;
+    // Width checks short-circuit ahead of the element reads.
+    if (values.w != 3 || indices.w != 3 || vp[0] != 1.f || vp[1] != 2.f || vp[2] != 5.f
+            || ip[0] != 5 || ip[1] != 1 || ip[2] != 3)
+    {
+        fprintf(stderr, "test_topk_multiple_nans k3 mismatch\n");
+        return -1;
+    }
+
+    return 0;
+}
+
+// sorted=0 must return the same SET of top-k values as sorted=1.
+static int test_topk_sorted0_vs_sorted1()
+{
+    // Fixed 8-element input; the value 1 appears twice (indices 1 and 3).
+    ncnn::Mat a(8);
+    float* ptr = a;
+    ptr[0] = 3.f;
+    ptr[1] = 1.f;
+    ptr[2] = 4.f;
+    ptr[3] = 1.f;
+    ptr[4] = 5.f;
+    ptr[5] = 9.f;
+    ptr[6] = 2.f;
+    ptr[7] = 6.f;
+
+    // sv: values from the sorted=1 run, uv: values from the sorted=0 run.
+    // The sixth run_topk argument is false here, so `dummy` is never inspected.
+    ncnn::Mat sv, uv, dummy;
+
+    // k=3, largest
+    if (run_topk(a, 0, 3, 1, 1, false, sv, dummy) != 0
+            || run_topk(a, 0, 3, 1, 0, false, uv, dummy) != 0)
+    {
+        fprintf(stderr, "test_topk_sorted0_vs_sorted1: forward failed\n");
+        return -1;
+    }
+    // Both outputs must hold exactly k values before they are copied out;
+    // a short result is a failure, not something to read past.
+    if (sv.w != 3 || uv.w != 3)
+    {
+        fprintf(stderr, "test_topk_sorted0_vs_sorted1 largest: shape mismatch sorted.w=%d unsorted.w=%d\n",
+                sv.w, uv.w);
+        return -1;
+    }
+    {
+        float s[3], u[3];
+        const float* sp = sv;
+        const float* up = uv;
+        for (int i = 0; i < 3; i++)
+        {
+            s[i] = sp[i];
+            u[i] = up[i];
+        }
+        // Order is irrelevant for sorted=0, so compare the two results as
+        // multisets by sorting local copies of both.
+        std::sort(s, s + 3);
+        std::sort(u, u + 3);
+        for (int i = 0; i < 3; i++)
+        {
+            if (s[i] != u[i])
+            {
+                fprintf(stderr, "test_topk_sorted0_vs_sorted1 largest: value set mismatch at %d: sorted=%f unsorted=%f\n",
+                        i, s[i], u[i]);
+                return -1;
+            }
+        }
+    }
+
+    // k=4, smallest
+    if (run_topk(a, 0, 4, 0, 1, false, sv, dummy) != 0
+            || run_topk(a, 0, 4, 0, 0, false, uv, dummy) != 0)
+    {
+        fprintf(stderr, "test_topk_sorted0_vs_sorted1: smallest forward failed\n");
+        return -1;
+    }
+    if (sv.w != 4 || uv.w != 4)
+    {
+        fprintf(stderr, "test_topk_sorted0_vs_sorted1 smallest: shape mismatch sorted.w=%d unsorted.w=%d\n",
+                sv.w, uv.w);
+        return -1;
+    }
+    {
+        float s[4], u[4];
+        const float* sp = sv;
+        const float* up = uv;
+        for (int i = 0; i < 4; i++)
+        {
+            s[i] = sp[i];
+            u[i] = up[i];
+        }
+        // Same multiset comparison as above, for the smallest-k case.
+        std::sort(s, s + 4);
+        std::sort(u, u + 4);
+        for (int i = 0; i < 4; i++)
+        {
+            if (s[i] != u[i])
+            {
+                fprintf(stderr, "test_topk_sorted0_vs_sorted1 smallest: value set mismatch at %d\n", i);
+                return -1;
+            }
+        }
+    }
+
+    return 0;
+}
+
+// Equal values → lower original index wins as tiebreak.
+static int test_topk_tie_breaking()
+{
+    // The value 5 appears three times (indices 0, 1, 3); 3 and 1 are unique.
+    ncnn::Mat a(5);
+    float* ptr = a;
+    ptr[0] = 5.f;
+    ptr[1] = 5.f;
+    ptr[2] = 3.f;
+    ptr[3] = 5.f;
+    ptr[4] = 1.f;
+
+    ncnn::Mat values, indices;
+
+    // Top-2 largest: 5@0, 5@1 (lower indices win)
+    if (run_topk(a, 0, 2, 1, 1, true, values, indices) != 0)
+    {
+        fprintf(stderr, "test_topk_tie_breaking: forward failed\n");
+        return -1;
+    }
+    // Validate the output widths first so the element reads (including the
+    // ones in the diagnostics below) can never run past a short or empty blob.
+    if (values.w != 2 || indices.w != 2)
+    {
+        fprintf(stderr, "test_topk_tie_breaking largest: shape mismatch values.w=%d indices.w=%d\n",
+                values.w, indices.w);
+        return -1;
+    }
+    const float* vp = values;
+    // The index blob is read as plain int through a void* cast.
+    const int* ip = (const int*)(const void*)indices;
+    if (vp[0] != 5.f || vp[1] != 5.f || ip[0] != 0 || ip[1] != 1)
+    {
+        fprintf(stderr, "test_topk_tie_breaking largest: got v=[%f,%f] i=[%d,%d]\n",
+                vp[0], vp[1], ip[0], ip[1]);
+        return -1;
+    }
+
+    // Top-2 smallest: 1@4, 3@2
+    if (run_topk(a, 0, 2, 0, 1, true, values, indices) != 0)
+    {
+        fprintf(stderr, "test_topk_tie_breaking: smallest forward failed\n");
+        return -1;
+    }
+    if (values.w != 2 || indices.w != 2)
+    {
+        fprintf(stderr, "test_topk_tie_breaking smallest: shape mismatch values.w=%d indices.w=%d\n",
+                values.w, indices.w);
+        return -1;
+    }
+    // Re-fetch the pointers: the outputs were overwritten by the second run.
+    vp = values;
+    ip = (const int*)(const void*)indices;
+    if (vp[0] != 1.f || vp[1] != 3.f || ip[0] != 4 || ip[1] != 2)
+    {
+        fprintf(stderr, "test_topk_tie_breaking smallest: got v=[%f,%f] i=[%d,%d]\n",
+                vp[0], vp[1], ip[0], ip[1]);
+        return -1;
+    }
+
+    return 0;
+}
+
+// k=0 must produce empty output without crashing.
+static int test_topk_k_zero()
+{
+    // Six-element ramp 0..5; the contents are irrelevant because k is 0.
+    ncnn::Mat a(6);
+    float* ptr = a;
+    for (int i = 0; i < 6; i++)
+        ptr[i] = (float)i;
+
+    ncnn::Mat values, indices;
+    if (run_topk(a, 0, 0, 1, 1, true, values, indices) != 0)
+    {
+        fprintf(stderr, "test_topk_k_zero: forward failed\n");
+        return -1;
+    }
+    // Only element counts are inspected; no output data is dereferenced.
+    if (values.total() != 0 || indices.total() != 0)
+    {
+        fprintf(stderr, "test_topk_k_zero: expected empty output, got values=%d indices=%d\n",
+                (int)values.total(), (int)indices.total());
+        return -1;
+    }
+    return 0;
+}
+
+// k > axis_size must be clamped to axis_size.
+static int test_topk_k_clamp() +{ + ncnn::Mat a(4); + float* ptr = a; + ptr[0] = 1.f; + ptr[1] = 4.f; + ptr[2] = 3.f; + ptr[3] = 2.f; + + ncnn::Mat values, indices; + if (run_topk(a, 0, 10, 1, 1, true, values, indices) != 0) + { + fprintf(stderr, "test_topk_k_clamp: forward failed\n"); + return -1; + } + const float* vp = values; + const int* ip = (const int*)(const void*)indices; + // clamped to k=4, sorted largest: 4@1, 3@2, 2@3, 1@0 + if ((int)values.total() != 4 || vp[0] != 4.f || vp[1] != 3.f || vp[2] != 2.f || vp[3] != 1.f + || ip[0] != 1 || ip[1] != 2 || ip[2] != 3 || ip[3] != 0) + { + fprintf(stderr, "test_topk_k_clamp: mismatch\n"); + return -1; + } + return 0; +} + +static int test_topk_values_only_fastpaths() +{ + ncnn::Mat a(5); + float* ptr = a; + ptr[0] = 1.f; + ptr[1] = -2.f; + ptr[2] = 4.f; + ptr[3] = 3.f; + ptr[4] = 0.f; + + ncnn::Mat values, dummy; + + // k=1, values-only (triggers NEON path on ARM when axis_size >= 4) + if (run_topk(a, 0, 1, 1, 0, false, values, dummy) != 0) + { + fprintf(stderr, "test_topk_values_only_fastpaths k1 failed\n"); + return -1; + } + if (values.w != 1 || ((const float*)values)[0] != 4.f) + { + fprintf(stderr, "test_topk_values_only_fastpaths k1 mismatch\n"); + return -1; + } + + // k=full, values-only (copy-all fast path) + if (run_topk(a, 0, 5, 1, 0, false, values, dummy) != 0) + { + fprintf(stderr, "test_topk_values_only_fastpaths fullk failed\n"); + return -1; + } + if (values.w != 5) + { + fprintf(stderr, "test_topk_values_only_fastpaths fullk shape mismatch\n"); + return -1; + } + const float* vptr = values; + for (int i = 0; i < 5; i++) + { + if (vptr[i] != ptr[i]) + { + fprintf(stderr, "test_topk_values_only_fastpaths fullk value mismatch at %d\n", i); + return -1; + } + } + + // k=1, values-only, smallest — exercises NEON min path + if (run_topk(a, 0, 1, 0, 0, false, values, dummy) != 0) + { + fprintf(stderr, "test_topk_values_only_fastpaths k1_min failed\n"); + return -1; + } + if (values.w != 1 || ((const 
float*)values)[0] != -2.f) + { + fprintf(stderr, "test_topk_values_only_fastpaths k1_min mismatch: got %f\n", + ((const float*)values)[0]); + return -1; + } + + return 0; +} + +static int test_topk_full_k() +{ + ncnn::Mat a2d = RandomMat(8, 5); + if (test_topk(a2d, 0, 5, 1, 1) != 0) return -1; + if (test_topk(a2d, 0, 5, 0, 1) != 0) return -1; + if (test_topk(a2d, 1, 8, 1, 1) != 0) return -1; + + ncnn::Mat a3d = RandomMat(6, 4, 3); + if (test_topk(a3d, 0, 3, 1, 1) != 0) return -1; + if (test_topk(a3d, 1, 4, 1, 1) != 0) return -1; + if (test_topk(a3d, 2, 6, 1, 1) != 0) return -1; + + return 0; +} + +int main() +{ + SRAND(7767517); + + return 0 + || test_topk_0() + || test_topk_1() + || test_topk_2() + || test_topk_3() + || test_topk_inf_order() + || test_topk_nan_robust() + || test_topk_nan_first_element() + || test_topk_multiple_nans() + || test_topk_sorted0_vs_sorted1() + || test_topk_tie_breaking() + || test_topk_k_zero() + || test_topk_k_clamp() + || test_topk_values_only_fastpaths() + || test_topk_full_k(); +} diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 284d8dac16fb..7bb7098d97b5 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -16,6 +16,7 @@ endif() add_subdirectory(caffe) add_subdirectory(mxnet) add_subdirectory(darknet) +add_subdirectory(onnx) if(NCNN_INT8) add_subdirectory(quantize) else() diff --git a/tools/pnnx/CMakeLists.txt b/tools/pnnx/CMakeLists.txt index e50ab4788c3d..73d5fdb9733c 100644 --- a/tools/pnnx/CMakeLists.txt +++ b/tools/pnnx/CMakeLists.txt @@ -83,7 +83,8 @@ else() message(WARNING "Building without TorchVision") endif() -include_directories(SYSTEM ${TORCH_INCLUDE_DIRS}) +# Torch includes are added per-target in src/CMakeLists.txt to avoid +# conflicts with system protobuf headers if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") # test if libtorch and protobuf has the same cxxabi version @@ -95,7 +96,10 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") endif() if((PNNX_TORCH_USE_CXX11_ABI AND 
PNNX_COMPILER_USE_CXX11_ABI) OR (NOT PNNX_TORCH_USE_CXX11_ABI AND NOT PNNX_COMPILER_USE_CXX11_ABI)) - find_package(protobuf CONFIG) + # Torch may have already registered protobuf targets — skip find_package if so + if(NOT TARGET protobuf::libprotobuf) + find_package(protobuf CONFIG) + endif() if(protobuf_FOUND) set(PROTOBUF_FOUND ${protobuf_FOUND}) @@ -109,20 +113,46 @@ if((PNNX_TORCH_USE_CXX11_ABI AND PNNX_COMPILER_USE_CXX11_ABI) OR (NOT PNNX_TORCH set_target_properties(protobuf::protoc PROPERTIES IMPORTED_LOCATION_RELEASE "${PROTOBUF_PROTOC_EXECUTABLE}") endif() endif() + + # Homebrew protobuf 34.x depends on Abseil — we need to link it explicitly + # because macOS doesn't resolve transitive dylib deps with @rpath properly + find_package(PkgConfig QUIET) + if(PKG_CONFIG_FOUND) + pkg_check_modules(ABSL QUIET absl_log_internal_check_op absl_die_if_null absl_log_internal_conditions absl_log_internal_message absl_examine_stack absl_statusor absl_synchronization absl_time) + if(ABSL_FOUND) + set(ABSL_LIBRARIES ${ABSL_LINK_LIBRARIES}) + endif() + endif() endif() # https://github.com/supertone-inc/onnxruntime-build set(onnxruntime_INSTALL_DIR "/home/nihui/osd/pnnx/install" CACHE STRING "") -find_library(onnxruntime_LIB NAMES onnxruntime PATHS ${onnxruntime_INSTALL_DIR}/lib64 ${onnxruntime_INSTALL_DIR}/lib) +find_library(onnxruntime_LIB NAMES onnxruntime + PATHS ${onnxruntime_INSTALL_DIR}/lib64 ${onnxruntime_INSTALL_DIR}/lib + /opt/homebrew/lib /usr/local/lib) if(onnxruntime_LIB) set(onnxruntime_FOUND TRUE) - add_library(onnxruntime::onnxruntime STATIC IMPORTED) + add_library(onnxruntime::onnxruntime SHARED IMPORTED) set_target_properties(onnxruntime::onnxruntime PROPERTIES IMPORTED_LOCATION ${onnxruntime_LIB}) - set_target_properties(onnxruntime::onnxruntime PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${onnxruntime_INSTALL_DIR}/include) + # prefer install-dir include, fall back to homebrew + if(EXISTS ${onnxruntime_INSTALL_DIR}/include) + 
set_target_properties(onnxruntime::onnxruntime PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${onnxruntime_INSTALL_DIR}/include) + else() + find_path(onnxruntime_INCLUDE_DIR onnxruntime_c_api.h + PATHS /opt/homebrew/include/onnxruntime /usr/local/include/onnxruntime) + if(onnxruntime_INCLUDE_DIR) + set_target_properties(onnxruntime::onnxruntime PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${onnxruntime_INCLUDE_DIR}) + endif() + endif() else() set(onnxruntime_FOUND FALSE) endif() +option(PNNX_DISABLE_ONNXRUNTIME "disable onnxruntime support and skip building onnx2pnnx" OFF) +if(PNNX_DISABLE_ONNXRUNTIME) + set(onnxruntime_FOUND FALSE) +endif() + option(PNNX_TNN2PNNX "build tnn2pnnx" ON) add_subdirectory(src) diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 82be5182ae16..de58e4d263de 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -593,6 +593,11 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/Tensor_reshape_as.cpp pass_ncnn/Tensor_repeat.cpp pass_ncnn/Tensor_unflatten.cpp + pass_ncnn/TopK.cpp + pass_ncnn/gatherelements.cpp + pass_ncnn/expand.cpp + pass_ncnn/tile.cpp + pass_ncnn/mod.cpp pass_ncnn/torch_addmm.cpp pass_ncnn/torch_amax.cpp pass_ncnn/torch_amin.cpp @@ -603,6 +608,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/torch_diag.cpp pass_ncnn/torch_flatten.cpp pass_ncnn/torch_flip.cpp + pass_ncnn/torch_gather.cpp pass_ncnn/torch_istft.cpp pass_ncnn/torch_logsumexp.cpp pass_ncnn/torch_matmul.cpp @@ -615,6 +621,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/torch_roll.cpp pass_ncnn/torch_slice_scatter.cpp pass_ncnn/torch_squeeze.cpp + pass_ncnn/tensor_to.cpp pass_ncnn/torch_sum.cpp pass_ncnn/torch_stft.cpp pass_ncnn/torch_t.cpp @@ -635,6 +642,12 @@ if(PROTOBUF_FOUND) add_library(onnxproto STATIC ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS}) target_include_directories(onnxproto PUBLIC ${PROTOBUF_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR}) target_link_libraries(onnxproto PUBLIC ${PROTOBUF_LIBRARIES}) + if(ABSL_LIBRARIES) + 
target_link_libraries(onnxproto PUBLIC ${ABSL_LIBRARIES}) + endif() + # Force system protobuf headers BEFORE any Torch-bundled old headers + # (Torch bundles an ancient protobuf that conflicts with system protobuf >= 22) + include_directories(BEFORE ${PROTOBUF_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR}) else() add_library(onnxproto STATIC onnx-data.proto onnx-ml.proto onnx-operators-ml.proto) target_include_directories(onnxproto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) @@ -674,6 +687,7 @@ set(torch2pnnx_SRCS add_library(torch2pnnx OBJECT ${torch2pnnx_SRCS}) target_compile_definitions(torch2pnnx PRIVATE BUILD_TORCH2PNNX) target_compile_options(torch2pnnx PUBLIC "${TORCH_CXX_FLAGS}") +target_include_directories(torch2pnnx SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS}) if(WIN32) target_compile_definitions(torch2pnnx PUBLIC NOMINMAX) @@ -687,6 +701,10 @@ if(PROTOBUF_FOUND) add_library(pnnx2onnx STATIC save_onnx.cpp ) + # Ensure Homebrew protobuf headers are found BEFORE Torch's bundled old ones + if(Protobuf_FOUND OR protobuf_MODULE_COMPATIBLE) + target_include_directories(pnnx2onnx BEFORE PRIVATE ${PROTOBUF_INCLUDE_DIR}) + endif() if(onnxruntime_FOUND) target_link_libraries(pnnx2onnx PRIVATE onnxruntime::onnxruntime) else() @@ -780,12 +798,18 @@ set(pnnx_SRCS add_executable(pnnx ${pnnx_SRCS}) set_property(SOURCE main.cpp APPEND PROPERTY COMPILE_DEFINITIONS BUILD_TORCH2PNNX) +target_include_directories(pnnx SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS}) target_link_libraries(pnnx PRIVATE torch2pnnx) if(TorchVision_FOUND) target_link_libraries(pnnx PRIVATE ${TORCHVISION_LIBRARY}) endif() +# Link Abseil (needed for protobuf 34.x on macOS/Homebrew) +if(ABSL_LIBRARIES) + target_link_libraries(pnnx PRIVATE ${ABSL_LIBRARIES}) +endif() + if(WIN32) target_link_libraries(pnnx PRIVATE ${TORCH_LIBRARIES}) else() @@ -800,6 +824,9 @@ endif() if(onnxruntime_FOUND) set_property(SOURCE main.cpp APPEND PROPERTY COMPILE_DEFINITIONS BUILD_ONNX2PNNX) target_link_libraries(pnnx PRIVATE onnx2pnnx) + 
if(PROTOBUF_FOUND) + target_link_libraries(pnnx PRIVATE onnxproto) + endif() endif() if(PNNX_TNN2PNNX) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index de861c4ca6bd..c3b587678c49 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1494,6 +1494,34 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath, con fprintf(pyfp, "\n"); + // output custom layer classes for pnnx operators + { + bool has_topk = false; + for (const Operator* op : ops) + { + if (op->type == "TopK") + { + has_topk = true; + break; + } + } + + if (has_topk) + { + fprintf(pyfp, "class TopK(nn.Module):\n"); + fprintf(pyfp, " def __init__(self, k=1, axis=1, largest=1, sorted=1):\n"); + fprintf(pyfp, " super(TopK, self).__init__()\n"); + fprintf(pyfp, " self.k = k\n"); + fprintf(pyfp, " self.axis = axis\n"); + fprintf(pyfp, " self.largest = largest\n"); + fprintf(pyfp, " self.sorted = sorted\n"); + fprintf(pyfp, " def forward(self, x):\n"); + fprintf(pyfp, " # Torch topk returns (values, indices)\n"); + fprintf(pyfp, " return torch.topk(x, self.k, dim=self.axis, largest=bool(self.largest), sorted=bool(self.sorted))\n"); + fprintf(pyfp, "\n"); + } + } + fprintf(pyfp, "class Model(nn.Module):\n"); fprintf(pyfp, " def __init__(self):\n"); fprintf(pyfp, " super(Model, self).__init__()\n"); @@ -1620,6 +1648,28 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath, con } } + // TopK modules + { + for (const Operator* op : ops) + { + if (op->type != "TopK") + continue; + + // TopK param ids: "0"=axis "1"=largest "2"=sorted "3"=k + int k_val = 1; + int axis_val = -1; + int largest_val = 1; + int sorted_val = 1; + if (op->params.count("3")) k_val = op->params.at("3").i; + if (op->params.count("0")) axis_val = op->params.at("0").i; + if (op->params.count("1")) largest_val = op->params.at("1").i; + if (op->params.count("2")) sorted_val = op->params.at("2").i; + + fprintf(pyfp, " self.%s = TopK(k=%d, axis=%d, largest=%d, sorted=%d)\n", 
+ sanitize_identifier(op->name).c_str(), k_val, axis_val, largest_val, sorted_val); + } + } + fprintf(pyfp, "\n"); // load weights @@ -2201,6 +2251,24 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath, con } fprintf(pyfp, ")\n"); } + else if (op->type == "TopK") + { + // self.topk_name() + for (size_t i = 0; i < op->outputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, ")\n"); + } else { if (op->type.find("::") == std::string::npos && op->type.find(".") == std::string::npos) diff --git a/tools/pnnx/src/load_onnx.cpp b/tools/pnnx/src/load_onnx.cpp index 3c788a0c4849..63559fee1827 100644 --- a/tools/pnnx/src/load_onnx.cpp +++ b/tools/pnnx/src/load_onnx.cpp @@ -13,7 +13,13 @@ #include #include +#if __has_include() #include +#elif __has_include() +#include +#elif __has_include() +#include +#endif #include "ir.h" diff --git a/tools/pnnx/src/pass_level2/torch_topk.cpp b/tools/pnnx/src/pass_level2/torch_topk.cpp index f3d7fae98ba4..bfc8ef51c7c5 100644 --- a/tools/pnnx/src/pass_level2/torch_topk.cpp +++ b/tools/pnnx/src/pass_level2/torch_topk.cpp @@ -13,11 +13,11 @@ class torch_topk : public GraphRewriterPass return R"PNNXIR(7767517 7 7 pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 k -pnnx.Input input_2 0 1 dim -pnnx.Input input_3 0 1 largest -pnnx.Input input_4 0 1 sorted -aten::topk op_0 5 2 input k dim largest sorted values indices +prim::Constant op_0 0 1 k value=%k +prim::Constant op_1 0 1 dim value=%dim +prim::Constant op_2 0 1 largest value=%largest +prim::Constant op_3 0 1 sorted value=%sorted +aten::topk op_4 5 2 input k dim largest sorted values 
indices pnnx.Output output 2 0 values indices )PNNXIR"; } diff --git a/tools/pnnx/src/pass_ncnn/TopK.cpp b/tools/pnnx/src/pass_ncnn/TopK.cpp new file mode 100644 index 000000000000..7a0a2370bebd --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/TopK.cpp @@ -0,0 +1,140 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +static int parameter_to_bool(const Parameter& p, int default_value) +{ + if (p.type == 1) + return p.b ? 1 : 0; + if (p.type == 2) + return p.i ? 1 : 0; + + return default_value; +} + +static void write_topk_params(Operator* op, const std::map& captured_params) +{ + int axis = -1; + if (captured_params.find("dim") != captured_params.end()) + { + const Parameter& dim_p = captured_params.at("dim"); + if (dim_p.type == 2) + axis = dim_p.i; + else if (dim_p.type == 5 && !dim_p.ai.empty()) + axis = dim_p.ai[0]; + } + + int largest = 1; + if (captured_params.find("largest") != captured_params.end()) + largest = parameter_to_bool(captured_params.at("largest"), 1); + + int sorted = 1; + if (captured_params.find("sorted") != captured_params.end()) + sorted = parameter_to_bool(captured_params.at("sorted"), 1); + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + if (axis == batch_index) + { + fprintf(stderr, "TopK along batch axis is not supported\n"); + return; + } + + int new_axis = axis; + if (axis >= 0) + new_axis = axis > batch_index ? axis - 1 : axis; + + // ncnn TopK uses ncnn-internal axis ordering (shape[0]=w=innermost), + // but pnnx axis is PyTorch-style (outermost=0). Convert. + const int pytorch_ndim = (int)op->inputs[0]->shape.size(); + const bool has_batch = (batch_index >= 0 && batch_index < pytorch_ndim); + const int ncnn_ndim = has_batch ? 
pytorch_ndim - 1 : pytorch_ndim; + if (new_axis >= 0 && ncnn_ndim > 0) + new_axis = (ncnn_ndim - 1) - new_axis; + + int k_val = 1; + if (captured_params.find("k") != captured_params.end()) + { + const Parameter& k_p = captured_params.at("k"); + if (k_p.type == 2) + k_val = k_p.i; + else if (k_p.type == 5 && !k_p.ai.empty()) + k_val = k_p.ai[0]; + } + + op->params["0"] = new_axis; + op->params["1"] = largest; + op->params["2"] = sorted; + op->params["3"] = k_val; +} + +class torch_topk : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input_0 0 1 input +torch.topk op_0 1 2 input values indices k=%k dim=%dim largest=%largest sorted=%sorted +pnnx.Output output 2 0 values indices +)PNNXIR"; + } + + const char* type_str() const + { + return "TopK"; + } + + const char* name_str() const + { + return "topk"; + } + + void write(Operator* op, const std::map& captured_params) const + { + write_topk_params(op, captured_params); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_topk, 20) + +class torch_topk_0 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 1 +pnnx.Input input_0 0 1 input +torch.topk op_0 1 1 input values k=%k dim=%dim largest=%largest sorted=%sorted +pnnx.Output output 1 0 values +)PNNXIR"; + } + + const char* type_str() const + { + return "TopK"; + } + + const char* name_str() const + { + return "topk"; + } + + void write(Operator* op, const std::map& captured_params) const + { + write_topk_params(op, captured_params); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_topk_0, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/expand.cpp b/tools/pnnx/src/pass_ncnn/expand.cpp new file mode 100644 index 000000000000..2a6f2cc74c42 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/expand.cpp @@ -0,0 +1,44 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: 
BSD-3-Clause + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class onnx_Expand : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 shape +Expand op_0 2 1 input shape output +pnnx.Output output 1 0 output +)PNNXIR"; + } + + const char* type_str() const + { + return "Expand"; + } + + const char* name_str() const + { + return "expand"; + } + + void write(Operator* op, const std::map& captured_params) const + { + // No parameters needed - shape comes as second input blob + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(onnx_Expand, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/gatherelements.cpp b/tools/pnnx/src/pass_ncnn/gatherelements.cpp new file mode 100644 index 000000000000..1eaa1f8d5508 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/gatherelements.cpp @@ -0,0 +1,54 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class onnx_GatherElements : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 data +pnnx.Input input_1 0 1 indices +GatherElements op_0 2 1 data indices out axis=%axis +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "GatherElements"; + } + + const char* name_str() const + { + return "gatherelements"; + } + + void write(Operator* op, const std::map& captured_params) const + { + int axis = 0; + if (captured_params.find("axis") != captured_params.end()) + { + const Parameter& axis_p = captured_params.at("axis"); + if (axis_p.type == 2) + axis = axis_p.i; + else if (axis_p.type == 5 && !axis_p.ai.empty()) + axis = axis_p.ai[0]; + } + + op->params["0"] = axis; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(onnx_GatherElements, 20) + +} // namespace 
ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/mod.cpp b/tools/pnnx/src/pass_ncnn/mod.cpp new file mode 100644 index 000000000000..0c92742d4bfe --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/mod.cpp @@ -0,0 +1,54 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class onnx_Mod : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 A +pnnx.Input input_1 0 1 B +Mod op_0 2 1 A B C fmod=%fmod +pnnx.Output output 1 0 C +)PNNXIR"; + } + + const char* type_str() const + { + return "Mod"; + } + + const char* name_str() const + { + return "mod"; + } + + void write(Operator* op, const std::map& captured_params) const + { + int fmod = 0; + if (captured_params.find("fmod") != captured_params.end()) + { + const Parameter& fmod_p = captured_params.at("fmod"); + if (fmod_p.type == 1) + fmod = fmod_p.b ? 1 : 0; + else if (fmod_p.type == 2) + fmod = fmod_p.i; + } + + op->params["0"] = fmod; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(onnx_Mod, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/tensor_to.cpp b/tools/pnnx/src/pass_ncnn/tensor_to.cpp new file mode 100644 index 000000000000..597079da7969 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/tensor_to.cpp @@ -0,0 +1,88 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class Tensor_to : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 2 +pnnx.Input input_0 0 1 input +Tensor.to op_0 1 1 input out copy=%copy dtype=%dtype +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Cast"; + } + + const char* name_str() const + { + return "to"; + } + + void write(Operator* op, const std::map& captured_params) const 
+ { + // Map pnnx operand type (0=null 1=f32 2=f64 3=f16 4=i32 5=i64 7=i8 13=bf16) + // to ncnn cast type (1=float32 2=float16 3=int8 4=bfloat16 5=int64 6=int32) + static const int pnnx_to_ncnn_cast_type[] = { + 0, // 0=null + 1, // 1=f32 → ncnn float32 + 1, // 2=f64 → ncnn float32 (no f64 in ncnn) + 2, // 3=f16 → ncnn float16 + 6, // 4=i32 → ncnn int32 + 5, // 5=i64 → ncnn int64 + 0, // 6=i16 → unsupported + 3, // 7=i8 → ncnn int8 + 0, // 8=u8 → unsupported + 0, // 9=bool → unsupported + 0, // 10=c64 + 0, // 11=c128 + 0, // 12=c32 + 4, // 13=bf16 → ncnn bfloat16 + }; + + const int in_pnnx_type = op->inputs[0]->type; + int type_from = 0; + if (in_pnnx_type >= 0 && in_pnnx_type <= 13) + type_from = pnnx_to_ncnn_cast_type[in_pnnx_type]; + + std::string dtype = "torch.float"; + if (captured_params.find("dtype") != captured_params.end()) + { + dtype = captured_params.at("dtype").s; + } + + int type_to = 0; + if (dtype == "torch.float" || dtype == "torch.float32") + type_to = 1; + else if (dtype == "torch.float16" || dtype == "torch.half") + type_to = 2; + else if (dtype == "torch.int8") + type_to = 3; + else if (dtype == "torch.bfloat16") + type_to = 4; + else if (dtype == "torch.int64" || dtype == "torch.long") + type_to = 5; + else if (dtype == "torch.int32" || dtype == "torch.int") + type_to = 6; + + op->params["0"] = type_from; + op->params["1"] = type_to; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(Tensor_to, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/tile.cpp b/tools/pnnx/src/pass_ncnn/tile.cpp new file mode 100644 index 000000000000..fcab9a18e2ff --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/tile.cpp @@ -0,0 +1,44 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class onnx_Tile : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 
input +pnnx.Input input_1 0 1 repeats +Tile op_0 2 1 input repeats output +pnnx.Output output 1 0 output +)PNNXIR"; + } + + const char* type_str() const + { + return "Tile"; + } + + const char* name_str() const + { + return "tile"; + } + + void write(Operator* op, const std::map& captured_params) const + { + // No parameters needed - repeats comes as second input blob + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(onnx_Tile, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_gather.cpp b/tools/pnnx/src/pass_ncnn/torch_gather.cpp new file mode 100644 index 000000000000..2df4571bce75 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_gather.cpp @@ -0,0 +1,66 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_gather : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 index +torch.gather op_0 2 1 input index out dim=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Gather"; + } + + const char* name_str() const + { + return "gather"; + } + + void write(Operator* op, const std::map& captured_params) const + { + int axis = 0; + if (captured_params.find("dim") != captured_params.end()) + { + const Parameter& dim_p = captured_params.at("dim"); + if (dim_p.type == 2) + axis = dim_p.i; + else if (dim_p.type == 5 && !dim_p.ai.empty()) + axis = dim_p.ai[0]; + } + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + if (axis == batch_index) + { + fprintf(stderr, "Gather along batch axis is not supported\n"); + return; + } + + int new_axis = axis; + if (axis >= 0) + new_axis = axis > batch_index ? 
axis - 1 : axis; + + op->params["0"] = new_axis; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_gather, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/fold_constants.cpp b/tools/pnnx/src/pass_onnx/fold_constants.cpp index 1ef0092a72ec..c79cb29f34a1 100644 --- a/tools/pnnx/src/pass_onnx/fold_constants.cpp +++ b/tools/pnnx/src/pass_onnx/fold_constants.cpp @@ -9,7 +9,15 @@ #include #include +#if __has_include() #include +#elif __has_include() +#include +#elif __has_include() +#include +#else +#error "onnxruntime_c_api.h not found" +#endif #include "dead_code_elimination.h" diff --git a/tools/pnnx/src/pass_onnx/shape_inference.cpp b/tools/pnnx/src/pass_onnx/shape_inference.cpp index 99dc652389d8..23986a7a7d2d 100644 --- a/tools/pnnx/src/pass_onnx/shape_inference.cpp +++ b/tools/pnnx/src/pass_onnx/shape_inference.cpp @@ -8,7 +8,15 @@ #include #include +#if __has_include() #include +#elif __has_include() +#include +#elif __has_include() +#include +#else +#error "onnxruntime_c_api.h not found" +#endif namespace pnnx { diff --git a/tools/pnnx/tests/onnx/CMakeLists.txt b/tools/pnnx/tests/onnx/CMakeLists.txt index 7398017dc73e..6681954e716c 100644 --- a/tools/pnnx/tests/onnx/CMakeLists.txt +++ b/tools/pnnx/tests/onnx/CMakeLists.txt @@ -193,6 +193,8 @@ pnnx_onnx_add_test(torch_split) pnnx_onnx_add_test(torch_squeeze) pnnx_onnx_add_test(torch_stack) pnnx_onnx_add_test(torch_sum) +pnnx_onnx_add_test(torch_gather) +pnnx_onnx_add_test(torch_topk) pnnx_onnx_add_test(torch_transpose) pnnx_onnx_add_test(torch_unbind) pnnx_onnx_add_test(torch_unsqueeze) diff --git a/tools/pnnx/tests/onnx/test_torch_gather.py b/tools/pnnx/tests/onnx/test_torch_gather.py new file mode 100644 index 000000000000..f97f74a8b098 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_torch_gather.py @@ -0,0 +1,72 @@ +# Copyright 2025 Tencent +# SPDX-License-Identifier: BSD-3-Clause + +import torch +import torch.nn as nn + + +class Model(nn.Module): + def 
__init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + # 1D gather along axis 0 + idx_1d = torch.tensor([2, 0, 1], dtype=torch.int64) + a = torch.gather(x, 0, idx_1d) + + # 2D gather along axis 0 + idx_2d_axis0 = torch.tensor([[0, 1], [1, 0], [0, 0]], dtype=torch.int64) + b = torch.gather(y, 0, idx_2d_axis0) + + # 2D gather along axis 1 + idx_2d_axis1 = torch.tensor([[1, 0, 2], [0, 2, 1]], dtype=torch.int64) + c = torch.gather(y, 1, idx_2d_axis1) + + # 3D gather along axis 1 + idx_3d = torch.zeros(2, 2, 4, dtype=torch.int64) + d = torch.gather(z, 1, idx_3d) + + # 3D gather along last axis (negative index) + idx_3d_last = torch.zeros(2, 3, 2, dtype=torch.int64) + e = torch.gather(z, -1, idx_3d_last) + + return a, b, c, d, e + + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(5) + y = torch.rand(3, 4) + z = torch.rand(2, 3, 4) + + a = net(x, y, z) + + # export onnx + torch.onnx.export(net, (x, y, z), "test_torch_gather.onnx", + opset_version=13) + + # onnx to pnnx + import os + os.system( + "../../src/pnnx test_torch_gather.onnx " + "inputshape=[5],[3,4],[2,3,4]" + ) + + # pnnx inference + import test_torch_gather_pnnx + b = test_torch_gather_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_torch_topk.py b/tools/pnnx/tests/onnx/test_torch_topk.py new file mode 100644 index 000000000000..dfd99ee2ac26 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_torch_topk.py @@ -0,0 +1,105 @@ +# Copyright 2026 Tencent +# SPDX-License-Identifier: BSD-3-Clause + +import torch +import torch.nn as nn + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, u, v): + x_values, x_indices = torch.topk( + x, 2, dim=1, largest=True, sorted=True + ) + x_k1_values, x_k1_indices = torch.topk( + x, 1, dim=1, 
largest=True, sorted=True + ) + x_k0_values, x_k0_indices = torch.topk( + x, 0, dim=1, largest=True, sorted=True + ) + x_unsorted_values, x_unsorted_indices = torch.topk( + x, 2, dim=1, largest=True, sorted=False + ) + x_values_only = torch.topk( + x, 3, dim=1, largest=True, sorted=True + )[0] + y_values, y_indices = torch.topk( + y, 4, dim=3, largest=False, sorted=True + ) + z_values, z_indices = torch.topk( + z, 3, dim=0, largest=True, sorted=True + ) + z_unsorted_values, z_unsorted_indices = torch.topk( + z, 3, dim=0, largest=True, sorted=False + ) + u_values, u_indices = torch.topk( + u, 2, dim=-1, largest=True, sorted=True + ) + v_values, v_indices = torch.topk( + v, 2, dim=1, largest=True, sorted=True + ) + + return ( + x_values, + x_indices, + x_k1_values, + x_k1_indices, + x_k0_values, + x_k0_indices, + x_unsorted_values, + x_unsorted_indices, + x_values_only, + y_values, + y_indices, + z_values, + z_indices, + z_unsorted_values, + z_unsorted_indices, + u_values, + u_indices, + v_values, + v_indices, + ) + + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + u = torch.rand(2, 8, 4) + v = torch.rand(2, 4, 3) + + a = net(x, y, z, u, v) + + # export onnx + torch.onnx.export(net, (x, y, z, u, v), "test_torch_topk.onnx") + + # onnx to pnnx + import os + + os.system( + "../../src/pnnx test_torch_topk.onnx " + "inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10],[2,8,4],[2,4,3]" + ) + + # pnnx inference + import test_torch_topk_pnnx + b = test_torch_topk_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)