diff --git a/.github/workflows/topk-linux-test.yml b/.github/workflows/topk-linux-test.yml new file mode 100644 index 000000000000..a29b5efc0a7c --- /dev/null +++ b/.github/workflows/topk-linux-test.yml @@ -0,0 +1,115 @@ +name: topk-linux-test +on: + push: + branches: + - topk-ci-tests + - fix-pnnx-onnx-topk-support + pull_request: + branches: + - master + +jobs: + x64-none: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: build + run: | + mkdir build && cd build + cmake -DCMAKE_BUILD_TYPE=Debug -DNCNN_RUNTIME_CPU=OFF \ + -DNCNN_SSE2=OFF -DNCNN_AVX=OFF \ + -DNCNN_OPENMP=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF -DNCNN_BUILD_TESTS=ON .. + cmake --build . --target test_topk -j$(nproc) + - name: test + run: cd build && ./tests/test_topk + + x64-sse2: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: build + run: | + mkdir build && cd build + cmake -DCMAKE_BUILD_TYPE=Debug -DNCNN_RUNTIME_CPU=OFF \ + -DNCNN_SSE2=ON -DNCNN_AVX=OFF \ + -DNCNN_OPENMP=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF -DNCNN_BUILD_TESTS=ON .. + cmake --build . --target test_topk -j$(nproc) + - name: test + run: cd build && ./tests/test_topk + + x64-avx2: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: build + run: | + mkdir build && cd build + cmake -DCMAKE_BUILD_TYPE=Debug -DNCNN_RUNTIME_CPU=OFF \ + -DNCNN_SSE2=ON -DNCNN_AVX=ON -DNCNN_F16C=ON -DNCNN_FMA=ON -DNCNN_AVX2=ON \ + -DNCNN_AVX512=OFF -DNCNN_XOP=OFF -DNCNN_AVXVNNI=OFF \ + -DNCNN_OPENMP=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF -DNCNN_BUILD_TESTS=ON .. + cmake --build . --target test_topk -j$(nproc) + - name: test + run: cd build && ./tests/test_topk + + simplestl-simplemath: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: build + run: | + mkdir build && cd build + cmake -DCMAKE_TOOLCHAIN_FILE=../toolchains/host-c.gcc.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Debug \ + -DNCNN_SIMPLESTL=ON -DNCNN_SIMPLEMATH=ON \ + -DNCNN_OPENMP=OFF -DNCNN_THREADS=OFF \ + -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF -DNCNN_BUILD_TESTS=ON .. + cmake --build . --target test_topk -j$(nproc) + - name: test + run: cd build && ./tests/test_topk + + linux-x86-gcc: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: install + run: sudo apt-get update && sudo apt-get install -y gcc-multilib g++-multilib + - name: build + run: | + mkdir build && cd build + cmake -DCMAKE_TOOLCHAIN_FILE=../toolchains/host.gcc-m32.toolchain.cmake \ + -DNCNN_BUILD_TESTS=ON -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF .. + cmake --build . --target test_topk -j$(nproc) + - name: test + run: cd build && ./tests/test_topk + - name: build-nosse + run: | + mkdir build-nosse && cd build-nosse + cmake -DCMAKE_TOOLCHAIN_FILE=../toolchains/host.gcc-m32.toolchain.cmake \ + -DNCNN_RUNTIME_CPU=OFF -DNCNN_SSE2=OFF -DNCNN_AVX=OFF \ + -DNCNN_BUILD_TESTS=ON -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF .. + cmake --build . --target test_topk -j$(nproc) + - name: test-nosse + run: cd build-nosse && ./tests/test_topk + + pnnx-onnx-topk: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + - name: setup-pytorch + run: | + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + pip3 install numpy packaging onnx onnxruntime + - name: build-pnnx + run: | + cd tools/pnnx + mkdir build && cd build + cmake -DCMAKE_BUILD_TYPE=Release .. + cmake --build . --config Release -j$(nproc) + - name: test-topk + run: | + cd tools/pnnx/build + ctest --output-on-failure -R test_onnx_torch_topk diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 614c3b8f31f1..3f518f11117b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -101,6 +101,8 @@ ncnn_add_layer(SPP OFF) ncnn_add_layer(TanH) ncnn_add_layer(Threshold) ncnn_add_layer(Tile) +ncnn_add_layer(TopK) +ncnn_add_layer(Gather) ncnn_add_layer(RNN) ncnn_add_layer(LSTM) ncnn_add_layer(BinaryOp) diff --git a/src/layer/cast.cpp b/src/layer/cast.cpp index 3dcff38f3cac..e18a7c3a8ae2 100644 --- a/src/layer/cast.cpp +++ b/src/layer/cast.cpp @@ -74,6 +74,16 @@ int Cast::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons // bfloat16 out_elemsize = 2 * elempack; } + else if (type_to == 5) + { + // int64 + out_elemsize = 8 * elempack; + } + else if (type_to == 6) + { + // int32 + out_elemsize = 4 * elempack; + } if (dims == 1) { @@ -173,6 +183,70 @@ int Cast::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons // TODO more cast type + if (type_from == 5 && type_to == 1) + { + // int64 → float32 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const long long* ptr = bottom_blob.channel(q); + float* outptr = top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = (float)ptr[i]; + } + } + } + + if (type_from == 1 && type_to == 5) + { + // float32 → int64 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = bottom_blob.channel(q); + long long* outptr = top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = (long long)ptr[i]; + } + } + } + + if (type_from == 6 && type_to == 1) + { + // int32 → float32 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const int* ptr = bottom_blob.channel(q); + float* outptr = top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = (float)ptr[i]; + } + } + } + + if (type_from == 1 && type_to == 6) + { + // float32 → int32 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = bottom_blob.channel(q); + int* outptr = top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = (int)ptr[i]; + } + } + } + return 0; } diff --git a/src/layer/cast.h b/src/layer/cast.h index 036e61efed04..22c8f5da4626 100644 --- a/src/layer/cast.h +++ b/src/layer/cast.h @@ -24,6 +24,8 @@ class Cast : public Layer // 2 = float16 // 3 = int8 // 4 = bfloat16 + // 5 = int64 + // 6 = int32 int type_from; int type_to; }; diff --git a/src/layer/gather.cpp b/src/layer/gather.cpp new file mode 100644 index 000000000000..850b65b3d121 --- /dev/null +++ b/src/layer/gather.cpp @@ -0,0 +1,121 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "gather.h" + +namespace ncnn { + +Gather::Gather() +{ + one_blob_only = false; + support_inplace = false; +} + +int Gather::load_param(const ParamDict& pd) +{ + axis = pd.get(0, 0); + + return 0; +} + +int Gather::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + if (bottom_blobs.size() < 2) + return -1; + + const Mat& input_blob = bottom_blobs[0]; + const Mat& index_blob = bottom_blobs[1]; + const int dims = input_blob.dims; + + // index_blob should contain int64 or int32 indices + // For simplicity we treat it as float and cast + const int index_size = (int)index_blob.total(); + + int positive_axis = axis < 0 ? axis + dims : axis; + if (positive_axis < 0 || positive_axis >= dims) + return -1; + + int shape[4] = {1, 1, 1, 1}; + shape[0] = input_blob.w; + if (dims >= 2) shape[1] = input_blob.h; + if (dims == 3) shape[2] = input_blob.c; + if (dims == 4) shape[2] = input_blob.c; // w*h*c layout + + const int axis_dim_size = shape[positive_axis]; + + // Output shape matches index_blob shape + const Mat& out_shape = index_blob; + + // Allocate output (same dtype as input, shape matches index) + Mat& top_blob = top_blobs[0]; + top_blob.create(out_shape.w, out_shape.h, out_shape.c, input_blob.elemsize, input_blob.elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + const float* inp = input_blob; + const int* idx = (const int*)index_blob; + float* out = top_blob; + + // General case: iterate over all output positions + // Map flat output index to multi-dimensional coords, + // then compute corresponding input position with index substitution + const int total_out = (int)top_blob.total(); + for (int i = 0; i < total_out; i++) + { + // Decompose flat index i into coordinates based on top_blob shape + int rem = i; + int coord_out[4] = {0, 0, 0, 0}; + if (top_blob.dims == 1) + { + coord_out[0] = rem; + } + else if (top_blob.dims == 2) + { + coord_out[0] = rem % top_blob.w; + coord_out[1] = rem / top_blob.w; + } + else if (top_blob.dims == 3) + { + int hw = top_blob.w * top_blob.h; + coord_out[0] = (rem % hw) % top_blob.w; + coord_out[1] = (rem % hw) / top_blob.w; + coord_out[2] = rem / hw; + } + + // Get index value at this output position + int gather_idx = idx[i]; + // Handle negative indices + if (gather_idx < 0) gather_idx += axis_dim_size; + + // Build input coordinate (same as output, but axis coord replaced) + int coord_in[4] = {coord_out[0], coord_out[1], coord_out[2], coord_out[3]}; + coord_in[positive_axis] = gather_idx; + + // Clamp to input bounds + if (coord_in[positive_axis] >= axis_dim_size) coord_in[positive_axis] = axis_dim_size - 1; + if (coord_in[positive_axis] < 0) coord_in[positive_axis] = 0; + + // Compute flat input index + int flat_in = 0; + if (dims == 1) + { + flat_in = coord_in[0]; + } + else if (dims == 2) + { + flat_in = coord_in[0] + coord_in[1] * input_blob.w; + } + else if (dims == 3) + { + // ncnn 3D layout: w * h * c, with cstride padding + size_t cstep = input_blob.cstep; + flat_in = coord_in[0] + coord_in[1] * input_blob.w + coord_in[2] * (int)cstep; + } + + out[i] = inp[flat_in]; + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/gather.h b/src/layer/gather.h new file mode 100644 index 000000000000..f8d24d9afb54 --- /dev/null +++ b/src/layer/gather.h @@ -0,0 +1,27 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LAYER_GATHER_H +#define LAYER_GATHER_H + +#include "layer.h" + +namespace ncnn { + +class Gather : public Layer +{ +public: + Gather(); + + virtual int load_param(const ParamDict& pd); + + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +public: + // param_0 = axis (default 0) + int axis; +}; + +} // namespace ncnn + +#endif // LAYER_GATHER_H diff --git a/src/layer/topk.cpp b/src/layer/topk.cpp new file mode 100644 index 000000000000..3b78fbfce3fe --- /dev/null +++ b/src/layer/topk.cpp @@ -0,0 +1,566 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "topk.h" + +#include +#include + +#if NCNN_SIMPLESTL +#include "simplestl.h" +#else +#include +#include +#endif + +#if __ARM_NEON +#include +#endif // __ARM_NEON + +namespace ncnn { + +static inline bool topk_isnan(float v) +{ + uint32_t u; + memcpy(&u, &v, sizeof(uint32_t)); + return (u & 0x7fffffff) > 0x7f800000; +} + +static inline bool topk_pair_comp(const std::pair& a, const std::pair& b, bool largest) +{ + const bool a_nan = topk_isnan(a.first); + const bool b_nan = topk_isnan(b.first); + + // Keep NaN at the end for both largest/smallest to ensure deterministic ordering. + if (a_nan || b_nan) + { + if (a_nan != b_nan) + return !a_nan && b_nan; + + return a.second < b.second; + } + + if (a.first != b.first) + return largest ? (a.first > b.first) : (a.first < b.first); + + return a.second < b.second; +} + +static inline bool topk_value_index_comp(float a_value, int a_index, float b_value, int b_index, bool largest) +{ + const bool a_nan = topk_isnan(a_value); + const bool b_nan = topk_isnan(b_value); + + if (a_nan || b_nan) + { + if (a_nan != b_nan) + return !a_nan && b_nan; + + return a_index < b_index; + } + + if (a_value != b_value) + return largest ? (a_value > b_value) : (a_value < b_value); + + return a_index < b_index; +} + +struct topk_pair_comparator +{ + topk_pair_comparator(bool _largest) + : largest(_largest) + { + } + + bool operator()(const std::pair& a, const std::pair& b) const + { + return topk_pair_comp(a, b, largest); + } + + bool largest; +}; + +TopK::TopK() +{ + one_blob_only = false; + support_inplace = false; +} + +int TopK::load_param(const ParamDict& pd) +{ + axis = pd.get(0, -1); + largest = pd.get(1, 1); + sorted = pd.get(2, 1); + k = pd.get(3, 1); + + return 0; +} + +int TopK::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + if (bottom_blobs.empty()) + return -1; + + const Mat& bottom_blob = bottom_blobs[0]; + + int _k = k; + if (bottom_blobs.size() >= 2) + { + const Mat& k_blob = bottom_blobs[1]; + if (k_blob.total() < 1) + return -1; + + _k = (int)((const float*)k_blob)[0]; + } + + if (bottom_blob.dims < 1 || bottom_blob.dims > 4) + return -100; + + const int dims = bottom_blob.dims; + + const int positive_axis = axis < 0 ? axis + dims : axis; + if (positive_axis < 0 || positive_axis >= dims) + return -1; + + int shape[4] = {1, 1, 1, 1}; + shape[0] = bottom_blob.w; + if (dims >= 2) shape[1] = bottom_blob.h; + if (dims >= 3) shape[2] = bottom_blob.dims == 3 ? bottom_blob.c : bottom_blob.d; + if (dims >= 4) shape[3] = bottom_blob.c; + + const int axis_size = shape[positive_axis]; + if (axis_size <= 0) + return -1; + + if (_k < 0) + return -1; + if (_k > axis_size) + _k = axis_size; + + int out_shape[4] = {shape[0], shape[1], shape[2], shape[3]}; + out_shape[positive_axis] = _k; + + Mat values; + if (dims == 1) values.create(out_shape[0], 4u, opt.blob_allocator); + if (dims == 2) values.create(out_shape[0], out_shape[1], 4u, opt.blob_allocator); + if (dims == 3) values.create(out_shape[0], out_shape[1], out_shape[2], 4u, opt.blob_allocator); + if (dims == 4) values.create(out_shape[0], out_shape[1], out_shape[2], out_shape[3], 4u, opt.blob_allocator); + if (values.empty()) + return -100; + + Mat indices; + if (top_blobs.size() >= 2) + { + if (dims == 1) indices.create(out_shape[0], 4u, opt.blob_allocator); + if (dims == 2) indices.create(out_shape[0], out_shape[1], 4u, opt.blob_allocator); + if (dims == 3) indices.create(out_shape[0], out_shape[1], out_shape[2], 4u, opt.blob_allocator); + if (dims == 4) indices.create(out_shape[0], out_shape[1], out_shape[2], out_shape[3], 4u, opt.blob_allocator); + if (indices.empty()) + return -100; + } + + if (_k == 0) + { + top_blobs[0] = values; + if (top_blobs.size() >= 2) + top_blobs[1] = indices; + + return 0; + } + + const float* ptr = bottom_blob; + float* outptr = values; + float* outidxptr = indices; + const bool output_indices = outidxptr != 0; + + int inner = 1; + for (int i = 0; i < positive_axis; i++) + { + inner *= shape[i]; + } + + int outer = 1; + for (int i = positive_axis + 1; i < dims; i++) + { + outer *= shape[i]; + } + + const bool largest_flag = largest != 0; + const bool sorted_flag = sorted != 0; + + const int total_lines = outer * inner; + + // ncnn 3-/4-D mats have a channel stride (cstep) that may be larger than w*h + // due to alignment padding. The flat inner/outer indexing must account for this: + // - when axis reduces a non-channel dim, the outer loop spans channels and + // the channel offset must use cstep rather than the product of spatial sizes; + // - when axis IS the channel dim, the per-element j-stride must be cstep. + const size_t in_cstep = (dims >= 3) ? (size_t)bottom_blob.cstep : 0; + const size_t out_cstep = (dims >= 3) ? values.cstep : 0; + const bool axis_is_channel = (dims >= 3 && positive_axis == dims - 1); + // spatial-only outer count: channels factored out so cstep can be used separately + const int c_channels = (!axis_is_channel && dims >= 3) ? shape[dims - 1] : 1; + const int outer_spatial = (dims >= 3 && !axis_is_channel) ? outer / c_channels : outer; + // stride when stepping along the axis in memory + const size_t in_axis_stride = axis_is_channel ? in_cstep : (size_t)inner; + const size_t out_axis_stride = axis_is_channel ? out_cstep : (size_t)inner; + + if (_k == 1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int line = 0; line < total_lines; line++) + { + int outer_i = line / inner; + int inner_i = line - outer_i * inner; + + size_t in_base, out_base; + if (!axis_is_channel && dims >= 3) + { + const int ci = outer_i / outer_spatial; + const int sp_i = outer_i % outer_spatial; + in_base = (size_t)ci * in_cstep + (size_t)sp_i * axis_size * inner + inner_i; + out_base = (size_t)ci * out_cstep + (size_t)sp_i * 1 * inner + inner_i; + } + else + { + in_base = (size_t)outer_i * axis_size * inner + inner_i; + out_base = (size_t)outer_i * 1 * inner + inner_i; + } + +#if __ARM_NEON + if (!output_indices && inner == 1 && axis_size >= 4) + { + const float* lineptr = ptr + in_base; + + float best_value = lineptr[0]; + int j = 1; + int has_nan = topk_isnan(best_value); + + for (; !has_nan && j + 3 < axis_size; j += 4) + { + float32x4_t v = vld1q_f32(lineptr + j); + uint32x4_t nan_mask = vmvnq_u32(vceqq_f32(v, v)); + uint32_t nan_mask_lanes[4]; + vst1q_u32(nan_mask_lanes, nan_mask); + if (nan_mask_lanes[0] || nan_mask_lanes[1] || nan_mask_lanes[2] || nan_mask_lanes[3]) + { + has_nan = 1; + break; + } + + float tmp[4]; + vst1q_f32(tmp, v); + + if (largest_flag) + { + if (tmp[0] > best_value) best_value = tmp[0]; + if (tmp[1] > best_value) best_value = tmp[1]; + if (tmp[2] > best_value) best_value = tmp[2]; + if (tmp[3] > best_value) best_value = tmp[3]; + } + else + { + if (tmp[0] < best_value) best_value = tmp[0]; + if (tmp[1] < best_value) best_value = tmp[1]; + if (tmp[2] < best_value) best_value = tmp[2]; + if (tmp[3] < best_value) best_value = tmp[3]; + } + } + + if (!has_nan) + { + for (; j < axis_size; j++) + { + const float candidate_value = lineptr[j]; + if (topk_isnan(candidate_value)) + { + has_nan = 1; + break; + } + + if (largest_flag) + { + if (candidate_value > best_value) + best_value = candidate_value; + } + else + { + if (candidate_value < best_value) + best_value = candidate_value; + } + } + } + + if (!has_nan) + { + outptr[out_base] = best_value; + continue; + } + } +#endif // __ARM_NEON + + float best_value = ptr[in_base]; + int best_index = 0; + + for (int j = 1; j < axis_size; j++) + { + const float candidate_value = ptr[in_base + j * in_axis_stride]; + if (topk_value_index_comp(candidate_value, j, best_value, best_index, largest_flag)) + { + best_value = candidate_value; + best_index = j; + } + } + + outptr[out_base] = best_value; + if (output_indices) + outidxptr[out_base] = (float)best_index; + } + + top_blobs[0] = values; + if (top_blobs.size() >= 2) + top_blobs[1] = indices; + + return 0; + } + + if (_k == axis_size && !sorted_flag) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int line = 0; line < total_lines; line++) + { + int outer_i = line / inner; + int inner_i = line - outer_i * inner; + + size_t in_base, out_base; + if (!axis_is_channel && dims >= 3) + { + const int ci = outer_i / outer_spatial; + const int sp_i = outer_i % outer_spatial; + in_base = (size_t)ci * in_cstep + (size_t)sp_i * axis_size * inner + inner_i; + out_base = (size_t)ci * out_cstep + (size_t)sp_i * _k * inner + inner_i; + } + else + { + in_base = (size_t)outer_i * axis_size * inner + inner_i; + out_base = (size_t)outer_i * _k * inner + inner_i; + } + + if (output_indices) + { + for (int j = 0; j < _k; j++) + { + outptr[out_base + j * out_axis_stride] = ptr[in_base + j * in_axis_stride]; + outidxptr[out_base + j * out_axis_stride] = (float)j; + } + } + else + { + for (int j = 0; j < _k; j++) + { + outptr[out_base + j * out_axis_stride] = ptr[in_base + j * in_axis_stride]; + } + } + } + + top_blobs[0] = values; + if (top_blobs.size() >= 2) + top_blobs[1] = indices; + + return 0; + } + + if (_k <= 4) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int line = 0; line < total_lines; line++) + { + int outer_i = line / inner; + int inner_i = line - outer_i * inner; + + size_t in_base, out_base; + if (!axis_is_channel && dims >= 3) + { + const int ci = outer_i / outer_spatial; + const int sp_i = outer_i % outer_spatial; + in_base = (size_t)ci * in_cstep + (size_t)sp_i * axis_size * inner + inner_i; + out_base = (size_t)ci * out_cstep + (size_t)sp_i * _k * inner + inner_i; + } + else + { + in_base = (size_t)outer_i * axis_size * inner + inner_i; + out_base = (size_t)outer_i * _k * inner + inner_i; + } + + float top_values[4]; + int top_indices[4]; + int top_count = 0; + + if (sorted_flag) + { + for (int j = 0; j < axis_size; j++) + { + const float candidate_value = ptr[in_base + j * in_axis_stride]; + + if (top_count < _k) + { + int insert_pos = top_count; + while (insert_pos > 0 && topk_value_index_comp(candidate_value, j, top_values[insert_pos - 1], top_indices[insert_pos - 1], largest_flag)) + { + top_values[insert_pos] = top_values[insert_pos - 1]; + top_indices[insert_pos] = top_indices[insert_pos - 1]; + insert_pos--; + } + + top_values[insert_pos] = candidate_value; + top_indices[insert_pos] = j; + top_count++; + } + else if (topk_value_index_comp(candidate_value, j, top_values[_k - 1], top_indices[_k - 1], largest_flag)) + { + int insert_pos = _k - 1; + while (insert_pos > 0 && topk_value_index_comp(candidate_value, j, top_values[insert_pos - 1], top_indices[insert_pos - 1], largest_flag)) + { + top_values[insert_pos] = top_values[insert_pos - 1]; + top_indices[insert_pos] = top_indices[insert_pos - 1]; + insert_pos--; + } + + top_values[insert_pos] = candidate_value; + top_indices[insert_pos] = j; + } + } + } + else + { + for (int j = 0; j < axis_size; j++) + { + const float candidate_value = ptr[in_base + j * in_axis_stride]; + + if (top_count < _k) + { + top_values[top_count] = candidate_value; + top_indices[top_count] = j; + top_count++; + } + else + { + int worst_pos = 0; + for (int t = 1; t < _k; t++) + { + if (topk_value_index_comp(top_values[worst_pos], top_indices[worst_pos], top_values[t], top_indices[t], largest_flag)) + worst_pos = t; + } + + if (topk_value_index_comp(candidate_value, j, top_values[worst_pos], top_indices[worst_pos], largest_flag)) + { + top_values[worst_pos] = candidate_value; + top_indices[worst_pos] = j; + } + } + } + } + + if (output_indices) + { + for (int j = 0; j < _k; j++) + { + outptr[out_base + j * out_axis_stride] = top_values[j]; + outidxptr[out_base + j * out_axis_stride] = (float)top_indices[j]; + } + } + else + { + for (int j = 0; j < _k; j++) + { + outptr[out_base + j * out_axis_stride] = top_values[j]; + } + } + } + + top_blobs[0] = values; + if (top_blobs.size() >= 2) + top_blobs[1] = indices; + + return 0; + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int line = 0; line < total_lines; line++) + { + std::vector > vec(axis_size); + + topk_pair_comparator comp(largest_flag); + + int outer_i = line / inner; + int inner_i = line - outer_i * inner; + + size_t in_base, out_base; + if (!axis_is_channel && dims >= 3) + { + const int ci = outer_i / outer_spatial; + const int sp_i = outer_i % outer_spatial; + in_base = (size_t)ci * in_cstep + (size_t)sp_i * axis_size * inner + inner_i; + out_base = (size_t)ci * out_cstep + (size_t)sp_i * _k * inner + inner_i; + } + else + { + in_base = (size_t)outer_i * axis_size * inner + inner_i; + out_base = (size_t)outer_i * _k * inner + inner_i; + } + + for (int j = 0; j < axis_size; j++) + { + vec[j].first = ptr[in_base + j * in_axis_stride]; + vec[j].second = j; + } + + if (_k < axis_size) + { +#if NCNN_SIMPLESTL + std::partial_sort(vec.begin(), vec.begin() + _k, vec.end(), comp); +#else + if (sorted_flag) + { + std::nth_element(vec.begin(), vec.begin() + _k, vec.end(), comp); + std::sort(vec.begin(), vec.begin() + _k, comp); + } + else + std::nth_element(vec.begin(), vec.begin() + _k, vec.end(), comp); +#endif + } + else + { + if (sorted_flag) +#if NCNN_SIMPLESTL + std::partial_sort(vec.begin(), vec.end(), vec.end(), comp); +#else + std::sort(vec.begin(), vec.end(), comp); +#endif + } + + if (output_indices) + { + for (int j = 0; j < _k; j++) + { + outptr[out_base + j * out_axis_stride] = vec[j].first; + outidxptr[out_base + j * out_axis_stride] = (float)vec[j].second; + } + } + else + { + for (int j = 0; j < _k; j++) + { + outptr[out_base + j * out_axis_stride] = vec[j].first; + } + } + } + + top_blobs[0] = values; + if (top_blobs.size() >= 2) + top_blobs[1] = indices; + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/topk.h b/src/layer/topk.h new file mode 100644 index 000000000000..947dc21343ff --- /dev/null +++ b/src/layer/topk.h @@ -0,0 +1,29 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LAYER_TOPK_H +#define LAYER_TOPK_H + +#include "layer.h" + +namespace ncnn { + +class TopK : public Layer +{ +public: + TopK(); + + virtual int load_param(const ParamDict& pd); + + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +public: + int axis; + int largest; + int sorted; + int k; +}; + +} // namespace ncnn + +#endif // LAYER_TOPK_H diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d3551879cbfc..697b91f89855 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -168,6 +168,7 @@ ncnn_add_layer_test(Swish) ncnn_add_layer_test(TanH) ncnn_add_layer_test(Threshold) ncnn_add_layer_test(Tile) +ncnn_add_layer_test(TopK) ncnn_add_layer_test(UnaryOp) ncnn_add_layer_test(Unfold) ncnn_add_layer_test(Yolov3DetectionOutput) diff --git a/tests/test_topk.cpp b/tests/test_topk.cpp new file mode 100644 index 000000000000..ac3375058e3f --- /dev/null +++ b/tests/test_topk.cpp @@ -0,0 +1,368 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "testutil.h" + +#if NCNN_SIMPLESTL +// simplemath.h conflicts with system math.h; define only what we need +static const float TEST_INF = 1.f / 0.f; +static const float TEST_NAN = 0.f / 0.f; +#define INFINITY TEST_INF +#define NAN TEST_NAN +#else +#include +#endif + +static int test_topk_cpu_forward(const ncnn::Mat& a, int axis, int k, int largest, int sorted, ncnn::Mat& values, ncnn::Mat& indices) +{ + ncnn::ParamDict pd; + pd.set(0, axis); + pd.set(1, largest); + pd.set(2, sorted); + pd.set(3, k); + + std::vector weights(0); + + ncnn::Option opt; + opt.num_threads = 1; + opt.use_vulkan_compute = false; + opt.use_packing_layout = false; + + ncnn::Layer* op = ncnn::create_layer_cpu("TopK"); + if (!op) + return -1; + + op->load_param(pd); + + ncnn::ModelBinFromMatArray mb(weights.data()); + op->load_model(mb); + + op->create_pipeline(opt); + + std::vector bottom_blobs(1); + bottom_blobs[0] = a; + + std::vector top_blobs(2); + int ret = op->forward(bottom_blobs, top_blobs, opt); + + op->destroy_pipeline(opt); + delete op; + + if (ret != 0) + return ret; + + values = top_blobs[0]; + indices = top_blobs[1]; + + return 0; +} + +static int test_topk_cpu_forward_values_only(const ncnn::Mat& a, int axis, int k, int largest, int sorted, ncnn::Mat& values) +{ + ncnn::ParamDict pd; + pd.set(0, axis); + pd.set(1, largest); + pd.set(2, sorted); + pd.set(3, k); + + std::vector weights(0); + + ncnn::Option opt; + opt.num_threads = 1; + opt.use_vulkan_compute = false; + opt.use_packing_layout = false; + + ncnn::Layer* op = ncnn::create_layer_cpu("TopK"); + if (!op) + return -1; + + op->load_param(pd); + + ncnn::ModelBinFromMatArray mb(weights.data()); + op->load_model(mb); + + op->create_pipeline(opt); + + std::vector bottom_blobs(1); + bottom_blobs[0] = a; + + std::vector top_blobs(1); + int ret = op->forward(bottom_blobs, top_blobs, opt); + + op->destroy_pipeline(opt); + delete op; + + if (ret != 0) + return ret; + + values = top_blobs[0]; + + return 0; +} + +static int test_topk(const ncnn::Mat& a, int axis, int k, int largest, int sorted) +{ + ncnn::ParamDict pd; + pd.set(0, axis); + pd.set(1, largest); + pd.set(2, sorted); + pd.set(3, k); + + std::vector weights(0); + + std::vector a0(1); + a0[0] = a; + + int ret = test_layer("TopK", pd, weights, a0, 2, 0.01f, TEST_LAYER_DISABLE_AUTO_INPUT_CASTING); + if (ret != 0) + { + fprintf(stderr, "test_topk failed a.dims=%d a=(%d %d %d %d) axis=%d k=%d largest=%d sorted=%d\n", a.dims, a.w, a.h, a.d, a.c, axis, k, largest, sorted); + } + + return ret; +} + +static int test_topk_0() +{ + ncnn::Mat a = RandomMat(13); + + return 0 + || test_topk(a, 0, 1, 1, 1) + || test_topk(a, 0, 5, 1, 1) + || test_topk(a, 0, 1, 0, 0) + || test_topk(a, -1, 7, 0, 1) + || test_topk(a, 0, 4, 1, 0) + || test_topk(a, 0, 9, 1, 1); +} + +static int test_topk_1() +{ + ncnn::Mat a = RandomMat(12, 17); + + return 0 + || test_topk(a, 0, 1, 1, 1) + || test_topk(a, 0, 5, 1, 1) + || test_topk(a, 1, 3, 1, 1) + || test_topk(a, -1, 8, 0, 1) + || test_topk(a, 1, 6, 0, 0) + || test_topk(a, -2, 7, 1, 1); +} + +static int test_topk_2() +{ + ncnn::Mat a = RandomMat(8, 9, 11); + + return 0 + || test_topk(a, 0, 3, 1, 1) + || test_topk(a, 1, 4, 1, 1) + || test_topk(a, 2, 2, 0, 1) + || test_topk(a, 2, 5, 1, 0) + || test_topk(a, -1, 6, 1, 1) + || test_topk(a, -2, 5, 0, 1) + || test_topk(a, -3, 7, 1, 1); +} + +static int test_topk_3() +{ + ncnn::Mat a = RandomMat(5, 7, 9, 10); + + return 0 + || test_topk(a, 0, 2, 1, 1) + || test_topk(a, 1, 3, 0, 1) + || test_topk(a, 2, 4, 1, 1) + || test_topk(a, 3, 4, 0, 0) + || test_topk(a, 3, 5, 1, 1) + || test_topk(a, -1, 6, 0, 1) + || test_topk(a, -2, 3, 1, 1) + || test_topk(a, -3, 4, 0, 1) + || test_topk(a, -4, 2, 1, 1); +} + +static int test_topk_inf_order() +{ + ncnn::Mat a(6); + float* ptr = a; + ptr[0] = 1.f; + ptr[1] = INFINITY; + ptr[2] = -2.f; + ptr[3] = -INFINITY; + ptr[4] = 0.5f; + ptr[5] = 3.f; + + ncnn::Mat values; + ncnn::Mat indices; + + int ret = test_topk_cpu_forward(a, 0, 2, 1, 1, values, indices); + if (ret != 0) + { + fprintf(stderr, "test_topk_inf_order largest failed ret=%d\n", ret); + return -1; + } + + const float* vptr = values; + const float* iptr = indices; + if (values.w != 2 || indices.w != 2 || vptr[0] != INFINITY || vptr[1] != 3.f || (int)iptr[0] != 1 || (int)iptr[1] != 5) + { + fprintf(stderr, "test_topk_inf_order largest result mismatch\n"); + return -1; + } + + ret = test_topk_cpu_forward(a, 0, 2, 0, 1, values, indices); + if (ret != 0) + { + fprintf(stderr, "test_topk_inf_order smallest failed ret=%d\n", ret); + return -1; + } + + vptr = values; + iptr = indices; + if (values.w != 2 || indices.w != 2 || vptr[0] != -INFINITY || vptr[1] != -2.f || (int)iptr[0] != 3 || (int)iptr[1] != 2) + { + fprintf(stderr, "test_topk_inf_order smallest result mismatch\n"); + return -1; + } + + return 0; +} + +static int test_topk_nan_robust() +{ + ncnn::Mat a(4); + float* ptr = a; + ptr[0] = 1.f; + ptr[1] = NAN; + ptr[2] = 2.f; + ptr[3] = -1.f; + + ncnn::Mat values; + ncnn::Mat indices; + + int ret = test_topk_cpu_forward(a, 0, 2, 1, 1, values, indices); + if (ret != 0) + { + fprintf(stderr, "test_topk_nan_robust sorted failed ret=%d\n", ret); + return -1; + } + + if (values.w != 2 || indices.w != 2) + { + fprintf(stderr, "test_topk_nan_robust sorted shape mismatch\n"); + return -1; + } + + const float* vptr = values; + const float* iptr = indices; + if (vptr[0] != 2.f || vptr[1] != 1.f || (int)iptr[0] != 2 || (int)iptr[1] != 0) + { + fprintf(stderr, "test_topk_nan_robust sorted largest mismatch\n"); + return -1; + } + + ret = test_topk_cpu_forward(a, 0, 2, 0, 1, values, indices); + if (ret != 0) + { + fprintf(stderr, "test_topk_nan_robust sorted smallest failed ret=%d\n", ret); + return -1; + } + + if (values.w != 2 || indices.w != 2) + { + fprintf(stderr, "test_topk_nan_robust sorted smallest shape mismatch\n"); + return -1; + } + + vptr = values; + iptr = indices; + if (vptr[0] != -1.f || vptr[1] != 1.f || (int)iptr[0] != 3 || (int)iptr[1] != 0) + { + fprintf(stderr, "test_topk_nan_robust sorted smallest mismatch\n"); + return -1; + } + + ret = test_topk_cpu_forward(a, 0, 2, 1, 0, values, indices); + if (ret != 0) + { + fprintf(stderr, "test_topk_nan_robust unsorted failed ret=%d\n", ret); + return -1; + } + + if (values.w != 2 || indices.w != 2) + { + fprintf(stderr, "test_topk_nan_robust unsorted shape mismatch\n"); + return -1; + } + + iptr = indices; + if ((int)iptr[0] < 0 || (int)iptr[0] >= 4 || (int)iptr[1] < 0 || (int)iptr[1] >= 4) + { + fprintf(stderr, "test_topk_nan_robust unsorted invalid indices\n"); + return -1; + } + + return 0; +} + +static int test_topk_values_only_fastpaths() +{ + ncnn::Mat a(5); + float* ptr = a; + ptr[0] = 1.f; + ptr[1] = -2.f; + ptr[2] = 4.f; + ptr[3] = 3.f; + ptr[4] = 0.f; + + ncnn::Mat values; + + int ret = test_topk_cpu_forward_values_only(a, 0, 1, 1, 0, values); + if (ret != 0) + { + fprintf(stderr, "test_topk_values_only_fastpaths k1 failed ret=%d\n", ret); + return -1; + } + + if (values.w != 1 || ((const float*)values)[0] != 4.f) + { + fprintf(stderr, "test_topk_values_only_fastpaths k1 result mismatch\n"); + return -1; + } + + ret = test_topk_cpu_forward_values_only(a, 0, 5, 1, 0, values); + if (ret != 0) + { + fprintf(stderr, "test_topk_values_only_fastpaths fullk failed ret=%d\n", ret); + return -1; + } + + if (values.w != 5) + { + fprintf(stderr, "test_topk_values_only_fastpaths fullk shape mismatch\n"); + return -1; + } + + const float* vptr = values; + for (int i = 0; i < 5; i++) + { + if (vptr[i] != ptr[i]) + { + fprintf(stderr, "test_topk_values_only_fastpaths fullk value mismatch\n"); + return -1; + } + } + + return 0; +} + +int main() +{ + SRAND(7767517); + + return 0 + || test_topk_0() + || test_topk_1() + || test_topk_2() + || test_topk_3() + || test_topk_inf_order() + || test_topk_nan_robust() + || test_topk_values_only_fastpaths(); +} diff --git a/tools/pnnx/CMakeLists.txt b/tools/pnnx/CMakeLists.txt index e50ab4788c3d..5b3250943cf8 100644 --- a/tools/pnnx/CMakeLists.txt +++ b/tools/pnnx/CMakeLists.txt @@ -83,7 +83,8 @@ else() message(WARNING "Building without TorchVision") endif() -include_directories(SYSTEM ${TORCH_INCLUDE_DIRS}) +# Torch includes are added per-target in src/CMakeLists.txt to avoid +# conflicts with system protobuf headers if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") # test if libtorch and protobuf has the same cxxabi version @@ -95,7 +96,10 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") endif() if((PNNX_TORCH_USE_CXX11_ABI AND PNNX_COMPILER_USE_CXX11_ABI) OR (NOT PNNX_TORCH_USE_CXX11_ABI AND NOT PNNX_COMPILER_USE_CXX11_ABI)) - find_package(protobuf CONFIG) + # Torch may have already registered protobuf targets — skip find_package if so + if(NOT TARGET protobuf::libprotobuf) + find_package(protobuf CONFIG) + endif() if(protobuf_FOUND) set(PROTOBUF_FOUND ${protobuf_FOUND}) @@ -109,20 +113,21 @@ if((PNNX_TORCH_USE_CXX11_ABI AND PNNX_COMPILER_USE_CXX11_ABI) OR (NOT PNNX_TORCH set_target_properties(protobuf::protoc PROPERTIES IMPORTED_LOCATION_RELEASE "${PROTOBUF_PROTOC_EXECUTABLE}") endif() endif() -endif() -# https://github.com/supertone-inc/onnxruntime-build -set(onnxruntime_INSTALL_DIR "/home/nihui/osd/pnnx/install" CACHE STRING "") -find_library(onnxruntime_LIB NAMES onnxruntime PATHS ${onnxruntime_INSTALL_DIR}/lib64 ${onnxruntime_INSTALL_DIR}/lib) -if(onnxruntime_LIB) - set(onnxruntime_FOUND TRUE) - add_library(onnxruntime::onnxruntime STATIC IMPORTED) - set_target_properties(onnxruntime::onnxruntime PROPERTIES IMPORTED_LOCATION ${onnxruntime_LIB}) - set_target_properties(onnxruntime::onnxruntime PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${onnxruntime_INSTALL_DIR}/include) -else() - set(onnxruntime_FOUND FALSE) + # Homebrew protobuf 34.x depends on Abseil — we need to link it explicitly + # because macOS doesn't resolve transitive dylib deps with @rpath properly + find_package(PkgConfig QUIET) + if(PKG_CONFIG_FOUND) + pkg_check_modules(ABSL QUIET absl_log_internal_check_op absl_die_if_null absl_log_internal_conditions absl_log_internal_message absl_examine_stack absl_statusor absl_synchronization absl_time) + if(ABSL_FOUND) + set(ABSL_LIBRARIES ${ABSL_LINK_LIBRARIES}) + endif() + endif() endif() +# Disable onnxruntime auto-detection — we only need torch2pnnx for YOLOv10 +set(onnxruntime_FOUND FALSE) + option(PNNX_TNN2PNNX "build tnn2pnnx" ON) add_subdirectory(src) diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 3e0c6f865a87..86c0593b9b37 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -592,6 +592,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/Tensor_reshape_as.cpp pass_ncnn/Tensor_repeat.cpp pass_ncnn/Tensor_unflatten.cpp + pass_ncnn/TopK.cpp pass_ncnn/torch_addmm.cpp pass_ncnn/torch_amax.cpp pass_ncnn/torch_amin.cpp @@ -602,6 +603,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/torch_diag.cpp pass_ncnn/torch_flatten.cpp pass_ncnn/torch_flip.cpp + pass_ncnn/torch_gather.cpp pass_ncnn/torch_istft.cpp pass_ncnn/torch_logsumexp.cpp pass_ncnn/torch_matmul.cpp @@ -614,6 +616,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/torch_roll.cpp pass_ncnn/torch_slice_scatter.cpp pass_ncnn/torch_squeeze.cpp + pass_ncnn/tensor_to.cpp pass_ncnn/torch_sum.cpp pass_ncnn/torch_stft.cpp pass_ncnn/torch_t.cpp @@ -634,6 +637,15 @@ if(PROTOBUF_FOUND) add_library(onnxproto STATIC ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS}) target_include_directories(onnxproto PUBLIC ${PROTOBUF_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR}) target_link_libraries(onnxproto PUBLIC ${PROTOBUF_LIBRARIES}) + if(ABSL_LIBRARIES) + target_link_libraries(onnxproto PUBLIC ${ABSL_LIBRARIES}) + endif() + # Force system protobuf headers BEFORE any Torch-bundled old headers + # (Torch bundles an ancient protobuf that conflicts with system protobuf >= 22) + set_property(DIRECTORY APPEND PROPERTY INCLUDE_DIRECTORIES_BEFORE + ${PROTOBUF_INCLUDE_DIR} + ${CMAKE_CURRENT_BINARY_DIR} + ) else() add_library(onnxproto STATIC onnx-data.proto onnx-ml.proto onnx-operators-ml.proto) target_include_directories(onnxproto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) @@ -673,6 +685,7 @@ set(torch2pnnx_SRCS add_library(torch2pnnx OBJECT ${torch2pnnx_SRCS}) target_compile_definitions(torch2pnnx PRIVATE BUILD_TORCH2PNNX) target_compile_options(torch2pnnx PUBLIC "${TORCH_CXX_FLAGS}") +target_include_directories(torch2pnnx SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS}) if(WIN32) target_compile_definitions(torch2pnnx PUBLIC NOMINMAX) @@ -686,6 +699,10 @@ if(PROTOBUF_FOUND) add_library(pnnx2onnx STATIC save_onnx.cpp ) + # Ensure Homebrew protobuf headers are found BEFORE Torch's bundled old ones + if(Protobuf_FOUND OR protobuf_MODULE_COMPATIBLE) + target_include_directories(pnnx2onnx BEFORE PRIVATE ${PROTOBUF_INCLUDE_DIR}) + endif() if(onnxruntime_FOUND) target_link_libraries(pnnx2onnx PRIVATE onnxruntime::onnxruntime) else() @@ -778,12 +795,18 @@ set(pnnx_SRCS add_executable(pnnx ${pnnx_SRCS}) set_property(SOURCE main.cpp APPEND PROPERTY COMPILE_DEFINITIONS BUILD_TORCH2PNNX) +target_include_directories(pnnx SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS}) target_link_libraries(pnnx PRIVATE torch2pnnx) if(TorchVision_FOUND) target_link_libraries(pnnx PRIVATE ${TORCHVISION_LIBRARY}) endif() +# Link Abseil (needed for protobuf 34.x on macOS/Homebrew) +if(ABSL_LIBRARIES) + target_link_libraries(pnnx PRIVATE ${ABSL_LIBRARIES}) +endif() + if(WIN32) target_link_libraries(pnnx PRIVATE ${TORCH_LIBRARIES}) else() diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 44e4b77fdf2f..456f51993b15 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1479,6 +1479,33 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath, con fprintf(pyfp, "\n"); + // output custom layer classes for pnnx operators + { + bool has_topk = false; + for (const Operator* op : ops) + { + if (op->type == "TopK") + { + has_topk = true; + break; + } + } + + if (has_topk) + { + fprintf(pyfp, "class TopK(nn.Module):\n"); + fprintf(pyfp, " def __init__(self, axis=1, largest=1, sorted=1):\n"); + fprintf(pyfp, " super(TopK, self).__init__()\n"); + fprintf(pyfp, " self.axis = axis\n"); + fprintf(pyfp, " self.largest = largest\n"); + fprintf(pyfp, " self.sorted = sorted\n"); + fprintf(pyfp, " def forward(self, x, k):\n"); + fprintf(pyfp, " # Torch topk returns (values, indices)\n"); + fprintf(pyfp, " return torch.topk(x, k.item() if hasattr(k, 'item') else k, dim=self.axis, largest=bool(self.largest), sorted=bool(self.sorted))\n"); + fprintf(pyfp, "\n"); + } + } + fprintf(pyfp, "class Model(nn.Module):\n"); fprintf(pyfp, " def __init__(self):\n"); fprintf(pyfp, " super(Model, self).__init__()\n"); @@ -1605,6 +1632,39 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath, con } } + // TopK modules + { + for (const Operator* op : ops) + { + if (op->type != "TopK") + continue; + + fprintf(pyfp, " self.%s = TopK(", sanitize_identifier(op->name).c_str()); + + int i = 0; + for (const auto& it : op->params) + { + fprintf(pyfp, "%s=", it.first.c_str()); + + const Parameter& param = it.second; + if (param.type == 2) + { + fprintf(pyfp, "%d", param.i); + } + else if (param.type == 1) + { + fprintf(pyfp, "%d", param.b ? 1 : 0); + } + + if (i + 1 != op->params.size()) + fprintf(pyfp, ", "); + i++; + } + + fprintf(pyfp, ")\n"); + } + } + fprintf(pyfp, "\n"); // load weights @@ -2186,6 +2246,24 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath, con } fprintf(pyfp, ")\n"); } + else if (op->type == "TopK") + { + // self.topk_name() + for (size_t i = 0; i < op->outputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, ")\n"); + } else { if (op->type.find("::") == std::string::npos && op->type.find(".") == std::string::npos) diff --git a/tools/pnnx/src/load_onnx.cpp b/tools/pnnx/src/load_onnx.cpp index 3c788a0c4849..601ac70d80d5 100644 --- a/tools/pnnx/src/load_onnx.cpp +++ b/tools/pnnx/src/load_onnx.cpp @@ -13,8 +13,6 @@ #include #include -#include - #include "ir.h" #include "pass_onnx/canonicalize.h" diff --git a/tools/pnnx/src/pass_level2/torch_topk.cpp b/tools/pnnx/src/pass_level2/torch_topk.cpp index f3d7fae98ba4..339271f95fb7 100644 --- a/tools/pnnx/src/pass_level2/torch_topk.cpp +++ b/tools/pnnx/src/pass_level2/torch_topk.cpp @@ -11,13 +11,13 @@ class torch_topk : public GraphRewriterPass const char* match_pattern_graph() const { return R"PNNXIR(7767517 -7 7 +12 7 pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 k -pnnx.Input input_2 0 1 dim -pnnx.Input input_3 0 1 largest -pnnx.Input input_4 0 1 sorted -aten::topk op_0 5 2 input k dim largest sorted values indices +prim::Constant op_0 0 1 k value=%k +prim::Constant op_1 0 1 dim value=%dim +prim::Constant op_2 0 1 largest value=%largest +prim::Constant op_3 0 1 sorted value=%sorted +aten::topk op_4 5 2 input k dim largest sorted values indices pnnx.Output output 2 0 values indices )PNNXIR"; } diff --git a/tools/pnnx/src/pass_ncnn/TopK.cpp b/tools/pnnx/src/pass_ncnn/TopK.cpp new file mode 100644 index 000000000000..2641493dd0fc --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/TopK.cpp @@ -0,0 +1,148 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +static int parameter_to_bool(const Parameter& p, int default_value) +{ + if (p.type == 1) + return p.b ? 1 : 0; + if (p.type == 2) + return p.i ? 1 : 0; + + return default_value; +} + +class torch_topk : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input_0 0 1 input +torch.topk op_0 1 2 input values indices k=%k dim=%dim largest=%largest sorted=%sorted +pnnx.Output output 2 0 values indices +)PNNXIR"; + } + + const char* type_str() const + { + return "TopK"; + } + + const char* name_str() const + { + return "topk"; + } + + void write(Operator* op, const std::map& captured_params) const + { + int axis = -1; + if (captured_params.find("dim") != captured_params.end()) + { + const Parameter& dim_p = captured_params.at("dim"); + if (dim_p.type == 2) + axis = dim_p.i; + else if (dim_p.type == 5 && !dim_p.ai.empty()) + axis = dim_p.ai[0]; + } + + int largest = 1; + if (captured_params.find("largest") != captured_params.end()) + largest = parameter_to_bool(captured_params.at("largest"), 1); + + int sorted = 1; + if (captured_params.find("sorted") != captured_params.end()) + sorted = parameter_to_bool(captured_params.at("sorted"), 1); + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + if (axis == batch_index) + { + fprintf(stderr, "TopK along batch axis is not supported\n"); + return; + } + + int new_axis = axis; + if (axis >= 0) + new_axis = axis > batch_index ? axis - 1 : axis; + + op->params["0"] = new_axis; + op->params["1"] = largest; + op->params["2"] = sorted; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_topk, 20) + +class torch_topk_0 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 1 +pnnx.Input input_0 0 1 input +torch.topk op_0 1 1 input values k=%k dim=%dim largest=%largest sorted=%sorted +pnnx.Output output 1 0 values +)PNNXIR"; + } + + const char* type_str() const + { + return "TopK"; + } + + const char* name_str() const + { + return "topk"; + } + + void write(Operator* op, const std::map& captured_params) const + { + int axis = -1; + if (captured_params.find("dim") != captured_params.end()) + { + const Parameter& dim_p = captured_params.at("dim"); + if (dim_p.type == 2) + axis = dim_p.i; + else if (dim_p.type == 5 && !dim_p.ai.empty()) + axis = dim_p.ai[0]; + } + + int largest = 1; + if (captured_params.find("largest") != captured_params.end()) + largest = parameter_to_bool(captured_params.at("largest"), 1); + + int sorted = 1; + if (captured_params.find("sorted") != captured_params.end()) + sorted = parameter_to_bool(captured_params.at("sorted"), 1); + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + if (axis == batch_index) + { + fprintf(stderr, "TopK along batch axis is not supported\n"); + return; + } + + int new_axis = axis; + if (axis >= 0) + new_axis = axis > batch_index ? axis - 1 : axis; + + op->params["0"] = new_axis; + op->params["1"] = largest; + op->params["2"] = sorted; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_topk_0, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/tensor_to.cpp b/tools/pnnx/src/pass_ncnn/tensor_to.cpp new file mode 100644 index 000000000000..252498fd0ffa --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/tensor_to.cpp @@ -0,0 +1,67 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class Tensor_to : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 2 +pnnx.Input input_0 0 1 input +Tensor.to op_0 1 1 input out copy=%copy dtype=%dtype +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Cast"; + } + + const char* name_str() const + { + return "to"; + } + + void write(Operator* op, const std::map& captured_params) const + { + // Map torch dtype to ncnn cast type + // torch.float = 1 (float32), torch.int64 = 5 (int64), torch.int32 = 6 (int32), etc. + // The input type is auto-detected, we only need to set the target type + std::string dtype = "torch.float"; + if (captured_params.find("dtype") != captured_params.end()) + { + dtype = captured_params.at("dtype").s; + } + + int type_to = 0; + if (dtype == "torch.float" || dtype == "torch.float32") + type_to = 1; + else if (dtype == "torch.float16" || dtype == "torch.half") + type_to = 2; + else if (dtype == "torch.int8") + type_to = 3; + else if (dtype == "torch.bfloat16") + type_to = 4; + else if (dtype == "torch.int64" || dtype == "torch.long") + type_to = 5; + else if (dtype == "torch.int32" || dtype == "torch.int") + type_to = 6; + + op->params["0"] = 0; // auto-detect input type + op->params["1"] = type_to; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(Tensor_to, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_gather.cpp b/tools/pnnx/src/pass_ncnn/torch_gather.cpp new file mode 100644 index 000000000000..13d1d69e0103 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_gather.cpp @@ -0,0 +1,54 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_gather : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 index +torch.gather op_0 2 1 input index out dim=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Gather"; + } + + const char* name_str() const + { + return "gather"; + } + + void write(Operator* op, const std::map& captured_params) const + { + int axis = 0; + if (captured_params.find("dim") != captured_params.end()) + { + const Parameter& dim_p = captured_params.at("dim"); + if (dim_p.type == 2) + axis = dim_p.i; + else if (dim_p.type == 5 && !dim_p.ai.empty()) + axis = dim_p.ai[0]; + } + + op->params["0"] = axis; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_gather, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/fold_constants.cpp b/tools/pnnx/src/pass_onnx/fold_constants.cpp index 1ef0092a72ec..c79cb29f34a1 100644 --- a/tools/pnnx/src/pass_onnx/fold_constants.cpp +++ b/tools/pnnx/src/pass_onnx/fold_constants.cpp @@ -9,7 +9,15 @@ #include #include +#if __has_include() #include +#elif __has_include() +#include +#elif __has_include() +#include +#else +#error "onnxruntime_c_api.h not found" +#endif #include "dead_code_elimination.h" diff --git a/tools/pnnx/src/pass_onnx/shape_inference.cpp b/tools/pnnx/src/pass_onnx/shape_inference.cpp index 99dc652389d8..23986a7a7d2d 100644 --- a/tools/pnnx/src/pass_onnx/shape_inference.cpp +++ b/tools/pnnx/src/pass_onnx/shape_inference.cpp @@ -8,7 +8,15 @@ #include #include +#if __has_include() #include +#elif __has_include() +#include +#elif __has_include() +#include +#else +#error "onnxruntime_c_api.h not found" +#endif namespace pnnx { diff --git a/tools/pnnx/tests/onnx/CMakeLists.txt b/tools/pnnx/tests/onnx/CMakeLists.txt index f029a669584d..ba821233ad12 100644 --- a/tools/pnnx/tests/onnx/CMakeLists.txt +++ b/tools/pnnx/tests/onnx/CMakeLists.txt @@ -191,6 +191,7 @@ pnnx_onnx_add_test(torch_split) pnnx_onnx_add_test(torch_squeeze) pnnx_onnx_add_test(torch_stack) pnnx_onnx_add_test(torch_sum) +pnnx_onnx_add_test(torch_topk) pnnx_onnx_add_test(torch_transpose) pnnx_onnx_add_test(torch_unbind) pnnx_onnx_add_test(torch_unsqueeze) diff --git a/tools/pnnx/tests/onnx/test_torch_topk.py b/tools/pnnx/tests/onnx/test_torch_topk.py new file mode 100644 index 000000000000..dfd99ee2ac26 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_torch_topk.py @@ -0,0 +1,105 @@ +# Copyright 2026 Tencent +# SPDX-License-Identifier: BSD-3-Clause + +import torch +import torch.nn as nn + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, u, v): + x_values, x_indices = torch.topk( + x, 2, dim=1, largest=True, sorted=True + ) + x_k1_values, x_k1_indices = torch.topk( + x, 1, dim=1, largest=True, sorted=True + ) + x_k0_values, x_k0_indices = torch.topk( + x, 0, dim=1, largest=True, sorted=True + ) + x_unsorted_values, x_unsorted_indices = torch.topk( + x, 2, dim=1, largest=True, sorted=False + ) + x_values_only = torch.topk( + x, 3, dim=1, largest=True, sorted=True + )[0] + y_values, y_indices = torch.topk( + y, 4, dim=3, largest=False, sorted=True + ) + z_values, z_indices = torch.topk( + z, 3, dim=0, largest=True, sorted=True + ) + z_unsorted_values, z_unsorted_indices = torch.topk( + z, 3, dim=0, largest=True, sorted=False + ) + u_values, u_indices = torch.topk( + u, 2, dim=-1, largest=True, sorted=True + ) + v_values, v_indices = torch.topk( + v, 2, dim=1, largest=True, sorted=True + ) + + return ( + x_values, + x_indices, + x_k1_values, + x_k1_indices, + x_k0_values, + x_k0_indices, + x_unsorted_values, + x_unsorted_indices, + x_values_only, + y_values, + y_indices, + z_values, + z_indices, + z_unsorted_values, + z_unsorted_indices, + u_values, + u_indices, + v_values, + v_indices, + ) + + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + u = torch.rand(2, 8, 4) + v = torch.rand(2, 4, 3) + + a = net(x, y, z, u, v) + + # export onnx + torch.onnx.export(net, (x, y, z, u, v), "test_torch_topk.onnx") + + # onnx to pnnx + import os + + os.system( + "../../src/pnnx test_torch_topk.onnx " + "inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10],[2,8,4],[2,4,3]" + ) + + # pnnx inference + import test_torch_topk_pnnx + b = test_torch_topk_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)