Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/layer/arm/convolution_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1299,6 +1299,25 @@ int Convolution_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, con
{
int elembits = bottom_blob.elembits();

// flattened blob, implement as InnerProduct
if (bottom_blob.dims == 1 && kernel_w == 1 && kernel_h == 1)
{
Mat bottom_blob_3d = bottom_blob.reshape(1, 1, bottom_blob.w, opt.workspace_allocator);
if (bottom_blob_3d.empty())
return -100;

Mat top_blob_3d;
int ret = forward_int8_arm(bottom_blob_3d, top_blob_3d, opt);
if (ret != 0)
return ret;

top_blob = top_blob_3d.reshape(top_blob_3d.w * top_blob_3d.h * top_blob_3d.c, opt.blob_allocator);
if (top_blob.empty())
return -100;

return 0;
}

Mat bottom_blob_int8 = bottom_blob;
if (elembits != 8)
{
Expand Down
46 changes: 46 additions & 0 deletions src/layer/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,52 @@ int Convolution::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Optio

// NCNN_LOGE("Convolution input %d x %d ksize=%d %d stride=%d %d", w, h, kernel_w, kernel_h, stride_w, stride_h);

// flattened blob, implement as InnerProduct
if (bottom_blob.dims == 1 && kernel_w == 1 && kernel_h == 1)
{
int num_input = weight_data_size / num_output;
if (bottom_blob.w * bottom_blob.elempack == num_input)
{
// call InnerProduct
ncnn::Layer* op = ncnn::create_layer_cpu(ncnn::LayerType::InnerProduct);

// set param
ncnn::ParamDict pd;
pd.set(0, num_output);
pd.set(1, bias_term);
pd.set(2, weight_data_size);
pd.set(8, int8_scale_term);

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve requantized int8 outputs

When int8_scale_term > 100, the normal convolution int8 path creates a 1-byte top blob and applies top_blob_int8_scales before returning. This new flattened fallback delegates to InnerProduct, whose int8 implementation only dequantizes to 4-byte floats and never loads the convolution top scale, so portable builds now return fp32 for flattened requantized convolutions that downstream int8 layers expect to stay int8. Gate this fallback off for >100 or preserve the convolution requantization path.

Useful? React with 👍 / 👎.

pd.set(9, activation_type);
pd.set(10, activation_params);

op->load_param(pd);

// set weights
ncnn::Mat weights[4];
weights[0] = weight_data;
weights[1] = bias_data;

if (int8_scale_term)
{
weights[2] = weight_data_int8_scales;
weights[3] = bottom_blob_int8_scales;
}

op->load_model(ModelBinFromMatArray(weights));

op->create_pipeline(opt);

// forward
int ret = op->forward(bottom_blob, top_blob, opt);

op->destroy_pipeline(opt);

delete op;

return ret;
}
}

const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;
const int kernel_extent_h = dilation_h * (kernel_h - 1) + 1;

Expand Down
19 changes: 19 additions & 0 deletions src/layer/loongarch/convolution_loongarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,25 @@ int Convolution_loongarch::forward_int8_loongarch(const Mat& bottom_blob, Mat& t
{
int elembits = bottom_blob.elembits();

// flattened blob, implement as InnerProduct
if (bottom_blob.dims == 1 && kernel_w == 1 && kernel_h == 1)
{
Mat bottom_blob_3d = bottom_blob.reshape(1, 1, bottom_blob.w, opt.workspace_allocator);
if (bottom_blob_3d.empty())
return -100;

Mat top_blob_3d;
int ret = forward_int8_loongarch(bottom_blob_3d, top_blob_3d, opt);
if (ret != 0)
return ret;

top_blob = top_blob_3d.reshape(top_blob_3d.w * top_blob_3d.h * top_blob_3d.c, opt.blob_allocator);
if (top_blob.empty())
return -100;

return 0;
}

Mat bottom_blob_int8 = bottom_blob;
if (elembits != 8)
{
Expand Down
19 changes: 19 additions & 0 deletions src/layer/mips/convolution_mips.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,25 @@ int Convolution_mips::forward_int8_mips(const Mat& bottom_blob, Mat& top_blob, c
{
int elembits = bottom_blob.elembits();

// flattened blob, implement as InnerProduct
if (bottom_blob.dims == 1 && kernel_w == 1 && kernel_h == 1)
{
Mat bottom_blob_3d = bottom_blob.reshape(1, 1, bottom_blob.w, opt.workspace_allocator);
if (bottom_blob_3d.empty())
return -100;

Mat top_blob_3d;
int ret = forward_int8_mips(bottom_blob_3d, top_blob_3d, opt);
if (ret != 0)
return ret;

top_blob = top_blob_3d.reshape(top_blob_3d.w * top_blob_3d.h * top_blob_3d.c, opt.blob_allocator);
if (top_blob.empty())
return -100;

return 0;
}

Mat bottom_blob_int8 = bottom_blob;
if (elembits != 8)
{
Expand Down
19 changes: 19 additions & 0 deletions src/layer/x86/convolution_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,25 @@ int Convolution_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, con
{
int elembits = bottom_blob.elembits();

// flattened blob, implement as InnerProduct
if (bottom_blob.dims == 1 && kernel_w == 1 && kernel_h == 1)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Guard the flattened reshape by weight size

This unconditional branch also catches true 1-D convolutions where the model has one input channel (weight_data_size / num_output == 1) but the input width is >1. In that case reshaping to 1x1xw makes the optimized int8 path think there are w input channels and it indexes the packed weights for channels that were never transformed; the generic path avoids this with the bottom_blob.w * bottom_blob.elempack == num_input check, so add the same guard here and in the parallel arm/mips/loongarch copies.

Useful? React with 👍 / 👎.

{
Mat bottom_blob_3d = bottom_blob.reshape(1, 1, bottom_blob.w, opt.workspace_allocator);
if (bottom_blob_3d.empty())
return -100;

Mat top_blob_3d;
int ret = forward_int8_x86(bottom_blob_3d, top_blob_3d, opt);
if (ret != 0)
return ret;

top_blob = top_blob_3d.reshape(top_blob_3d.w * top_blob_3d.h * top_blob_3d.c, opt.blob_allocator);
if (top_blob.empty())
return -100;

return 0;
}

Mat bottom_blob_int8 = bottom_blob;
if (elembits != 8)
{
Expand Down
130 changes: 130 additions & 0 deletions tests/test_convolution_5.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Copyright 2025 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#include "testutil.h"
#include "layer_type.h"

#include <cmath>
#include <cstdio>
#include <vector>

static int test_convolution_int8_1d(int num_input, int num_output)
{
ncnn::Mat a = RandomMat(num_input);
// scale up so int8 quantization rounding error is negligible
for (int i = 0; i < num_input; i++)
a[i] = roundf(a[i] * 10.f);

ncnn::ParamDict pd;
pd.set(0, num_output);
pd.set(1, 1);
pd.set(11, 1);
pd.set(2, 1);
pd.set(12, 1);
pd.set(3, 1);
pd.set(13, 1);
pd.set(4, 0);
pd.set(5, 1);
pd.set(6, num_input * num_output);
pd.set(8, 1); // int8_scale_term

// int8 weights: weight, bias, per-output weight scales, input scale
std::vector<ncnn::Mat> weights_int8(4);
weights_int8[0] = RandomS8Mat(num_input * num_output);
weights_int8[1] = RandomMat(num_output);
weights_int8[2] = RandomMat(num_output);
for (int i = 0; i < num_output; i++)
weights_int8[2][i] = 1.f;
weights_int8[3] = RandomMat(1);
weights_int8[3][0] = 1.f;

// fp32 reference weights, converted from the same int8 values
std::vector<ncnn::Mat> weights_fp32(2);
weights_fp32[0] = ncnn::Mat(num_input * num_output);
for (int i = 0; i < num_input * num_output; i++)
weights_fp32[0][i] = (float)((signed char*)weights_int8[0])[i];
weights_fp32[1] = weights_int8[1];

// fp32 reference path: Convolution::forward will redirect to InnerProduct
ncnn::Mat ref;
{
ncnn::ParamDict pd_fp32 = pd;
pd_fp32.set(8, 0);

ncnn::Layer* op = ncnn::create_layer_cpu(ncnn::LayerType::Convolution);
op->load_param(pd_fp32);
op->load_model(ncnn::ModelBinFromMatArray(weights_fp32.data()));

ncnn::Option opt;
opt.num_threads = 1;
opt.use_int8_inference = false;
opt.use_packing_layout = false;

int ret = op->create_pipeline(opt);
if (ret != 0)
return ret;
ret = op->forward(a, ref, opt);
op->destroy_pipeline(opt);
delete op;
if (ret != 0)
return ret;
}

// int8 path: was missing the flattened blob handling before the fix
ncnn::Mat out;
{
ncnn::Layer* op = ncnn::create_layer_cpu(ncnn::LayerType::Convolution);
op->load_param(pd);
op->load_model(ncnn::ModelBinFromMatArray(weights_int8.data()));

ncnn::Option opt;
opt.num_threads = 1;
opt.use_int8_inference = true;
opt.use_packing_layout = false;

int ret = op->create_pipeline(opt);
if (ret != 0)
return ret;
ret = op->forward(a, out, opt);
op->destroy_pipeline(opt);
delete op;
if (ret != 0)
return ret;
}

// compare shape and values against fp32 reference
if (ref.dims != out.dims || ref.w != out.w || ref.h != out.h || ref.c != out.c)
{
fprintf(stderr, "test_convolution_int8_1d shape mismatch num_input=%d num_output=%d ref(dims=%d,w=%d,h=%d,c=%d) out(dims=%d,w=%d,h=%d,c=%d)\n",
num_input, num_output,
ref.dims, ref.w, ref.h, ref.c,
out.dims, out.w, out.h, out.c);
return -1;
}

float maxerr = 0.f;
for (int i = 0; i < ref.w; i++)
{
float err = fabsf(ref[i] - out[i]);
if (err > maxerr)
maxerr = err;
}

if (maxerr > 0.01f)
{
fprintf(stderr, "test_convolution_int8_1d failed num_input=%d num_output=%d maxerr=%f\n", num_input, num_output, maxerr);
return -1;
}

return 0;
}

int main()
{
SRAND(7767517);

return 0
|| test_convolution_int8_1d(8, 8)
|| test_convolution_int8_1d(16, 8)
|| test_convolution_int8_1d(17, 5);
}