diff --git a/src/layer/vulkan/rmsnorm_vulkan.cpp b/src/layer/vulkan/rmsnorm_vulkan.cpp index ac638f15cad8..5bd97256e5ba 100644 --- a/src/layer/vulkan/rmsnorm_vulkan.cpp +++ b/src/layer/vulkan/rmsnorm_vulkan.cpp @@ -28,6 +28,10 @@ RMSNorm_vulkan::RMSNorm_vulkan() pipeline_rmsnorm_reduce_mean_pack4 = 0; pipeline_rmsnorm_coeffs_pack4 = 0; pipeline_rmsnorm_norm_pack4 = 0; + + // subgroup + pipeline_rmsnorm_reduce_subgroup = 0; + pipeline_rmsnorm_reduce_subgroup_pack4 = 0; } int RMSNorm_vulkan::create_pipeline(const Option& opt) @@ -102,6 +106,18 @@ int RMSNorm_vulkan::create_pipeline(const Option& opt) pipeline_rmsnorm_norm_pack4->set_optimal_local_size_xyz(8, 8, 1); pipeline_rmsnorm_norm_pack4->create(LayerShaderType::rmsnorm_norm_pack4, opt, spec); } + + if (vkdev->info.support_subgroup_ops() & VK_SUBGROUP_FEATURE_ARITHMETIC_BIT) + { + pipeline_rmsnorm_reduce_subgroup = new Pipeline(vkdev); + pipeline_rmsnorm_reduce_subgroup->set_local_size_xyz(256, 1, 1); + pipeline_rmsnorm_reduce_subgroup->create(LayerShaderType::rmsnorm_reduce_subgroup, opt, std::vector()); + + pipeline_rmsnorm_reduce_subgroup_pack4 = new Pipeline(vkdev); + pipeline_rmsnorm_reduce_subgroup_pack4->set_local_size_xyz(256, 1, 1); + pipeline_rmsnorm_reduce_subgroup_pack4->create(LayerShaderType::rmsnorm_reduce_subgroup_pack4, opt, std::vector()); + } + return 0; } @@ -139,6 +155,11 @@ int RMSNorm_vulkan::destroy_pipeline(const Option&) pipeline_rmsnorm_coeffs_pack4 = 0; pipeline_rmsnorm_norm_pack4 = 0; + delete pipeline_rmsnorm_reduce_subgroup; + pipeline_rmsnorm_reduce_subgroup = 0; + delete pipeline_rmsnorm_reduce_subgroup_pack4; + pipeline_rmsnorm_reduce_subgroup_pack4 = 0; + return 0; } @@ -196,106 +217,129 @@ int RMSNorm_vulkan::forward_inplace(VkMat& bottom_top_blob, VkCompute& cmd, cons } int num_groups_total = num_groups_per_channel * channels; - // 1) x -> x^2 - VkMat square_workspace(w, h, channels, elemsize, elempack, opt.workspace_vkallocator); + VkMat rms_workspace(num_groups_total, 4u * elempack, elempack, opt.workspace_vkallocator); + + const Pipeline* pipeline_reduce_subgroup = elempack == 4 ? pipeline_rmsnorm_reduce_subgroup_pack4 : pipeline_rmsnorm_reduce_subgroup; + if (pipeline_reduce_subgroup) { std::vector bindings(2); bindings[0] = bottom_top_blob; - bindings[1] = square_workspace; + bindings[1] = rms_workspace; - std::vector constants(4); - constants[0].i = w; - constants[1].i = h; - constants[2].i = channels; - constants[3].i = cstep; + std::vector constants(3); + constants[0].i = group_size; + constants[1].i = num_groups_per_channel; + constants[2].i = (int)cstep; - const Pipeline* pipe_sq = elempack == 4 ? pipeline_rmsnorm_square_pack4 : pipeline_rmsnorm_square; + VkMat dispatcher; + dispatcher.w = 1; + dispatcher.h = num_groups_total; + dispatcher.c = 1; - cmd.record_pipeline(pipe_sq, bindings, constants, square_workspace); + cmd.record_pipeline(pipeline_reduce_subgroup, bindings, constants, dispatcher); } - - // 2) reduce sum4 (square) -> ... -> mean - VkMat rms_workspace(num_groups_total, 4u * elempack, elempack, opt.workspace_vkallocator); + else { - int reduced_w = (group_size + 3) / 4; - VkMat sqsum_workspace; - sqsum_workspace.create(reduced_w, num_groups_per_channel, channels, 4u * elempack, elempack, opt.workspace_vkallocator); - + // 1) x -> x^2 + VkMat square_workspace(w, h, channels, elemsize, elempack, opt.workspace_vkallocator); { std::vector bindings(2); - bindings[0] = square_workspace; - bindings[1] = sqsum_workspace; + bindings[0] = bottom_top_blob; + bindings[1] = square_workspace; - std::vector constants(8); - constants[0].i = group_size; - constants[1].i = num_groups_per_channel; + std::vector constants(4); + constants[0].i = w; + constants[1].i = h; constants[2].i = channels; - constants[3].i = square_workspace.cstep; - constants[4].i = reduced_w; - constants[5].i = num_groups_per_channel; - constants[6].i = channels; - constants[7].i = sqsum_workspace.cstep; - - VkMat dispatcher; - dispatcher.w = reduced_w; - dispatcher.h = num_groups_per_channel; - dispatcher.c = channels; - const Pipeline* p_reduce = elempack == 4 ? pipeline_rmsnorm_reduce_sum4_fp16_to_fp32_pack4 - : pipeline_rmsnorm_reduce_sum4_fp16_to_fp32; - cmd.record_pipeline(p_reduce, bindings, constants, dispatcher); - } - int pb = 1; - while (sqsum_workspace.w > 1) - { - int current_w = sqsum_workspace.w; - reduced_w = (current_w + 3) / 4; - - VkMat sqsum_reduced; - sqsum_reduced.create(reduced_w, num_groups_per_channel, channels, 4u * elempack, elempack, opt.workspace_vkallocator); + constants[3].i = cstep; - std::vector bindings(2); - bindings[0] = sqsum_workspace; - bindings[1] = sqsum_reduced; + const Pipeline* pipe_sq = elempack == 4 ? pipeline_rmsnorm_square_pack4 : pipeline_rmsnorm_square; - std::vector constants(8); - constants[0].i = current_w; - constants[1].i = num_groups_per_channel; - constants[2].i = channels; - constants[3].i = sqsum_workspace.cstep; - constants[4].i = reduced_w; - constants[5].i = num_groups_per_channel; - constants[6].i = channels; - constants[7].i = sqsum_reduced.cstep; - - VkMat dispatcher; - dispatcher.w = reduced_w; - dispatcher.h = num_groups_per_channel; - dispatcher.c = channels; - const Pipeline* p_iter = elempack == 4 ? pipeline_rmsnorm_reduce_sum4_fp32_pack4[pb % 2] - : pipeline_rmsnorm_reduce_sum4_fp32[pb % 2]; - cmd.record_pipeline(p_iter, bindings, constants, dispatcher); - pb++; - sqsum_workspace = sqsum_reduced; + cmd.record_pipeline(pipe_sq, bindings, constants, square_workspace); } + // 2) reduce sum4 (square) -> ... -> mean { - std::vector bindings(2); - bindings[0] = sqsum_workspace; - bindings[1] = rms_workspace; - - std::vector constants(5); - constants[0].i = sqsum_workspace.w; - constants[1].i = num_groups_per_channel; - constants[2].i = channels; - constants[3].i = sqsum_workspace.cstep; - constants[4].f = (float)group_size; - - VkMat dispatcher; - dispatcher.w = 1; - dispatcher.h = num_groups_per_channel; - dispatcher.c = channels; - const Pipeline* p_mean = elempack == 4 ? pipeline_rmsnorm_reduce_mean_pack4 : pipeline_rmsnorm_reduce_mean; - cmd.record_pipeline(p_mean, bindings, constants, dispatcher); + int reduced_w = (group_size + 3) / 4; + VkMat sqsum_workspace; + sqsum_workspace.create(reduced_w, num_groups_per_channel, channels, 4u * elempack, elempack, opt.workspace_vkallocator); + + { + std::vector bindings(2); + bindings[0] = square_workspace; + bindings[1] = sqsum_workspace; + + std::vector constants(8); + constants[0].i = group_size; + constants[1].i = num_groups_per_channel; + constants[2].i = channels; + constants[3].i = square_workspace.cstep; + constants[4].i = reduced_w; + constants[5].i = num_groups_per_channel; + constants[6].i = channels; + constants[7].i = sqsum_workspace.cstep; + + VkMat dispatcher; + dispatcher.w = reduced_w; + dispatcher.h = num_groups_per_channel; + dispatcher.c = channels; + const Pipeline* p_reduce = elempack == 4 ? pipeline_rmsnorm_reduce_sum4_fp16_to_fp32_pack4 + : pipeline_rmsnorm_reduce_sum4_fp16_to_fp32; + cmd.record_pipeline(p_reduce, bindings, constants, dispatcher); + } + int pb = 1; + while (sqsum_workspace.w > 1) + { + int current_w = sqsum_workspace.w; + reduced_w = (current_w + 3) / 4; + + VkMat sqsum_reduced; + sqsum_reduced.create(reduced_w, num_groups_per_channel, channels, 4u * elempack, elempack, opt.workspace_vkallocator); + + std::vector bindings(2); + bindings[0] = sqsum_workspace; + bindings[1] = sqsum_reduced; + + std::vector constants(8); + constants[0].i = current_w; + constants[1].i = num_groups_per_channel; + constants[2].i = channels; + constants[3].i = sqsum_workspace.cstep; + constants[4].i = reduced_w; + constants[5].i = num_groups_per_channel; + constants[6].i = channels; + constants[7].i = sqsum_reduced.cstep; + + VkMat dispatcher; + dispatcher.w = reduced_w; + dispatcher.h = num_groups_per_channel; + dispatcher.c = channels; + const Pipeline* p_iter = elempack == 4 ? pipeline_rmsnorm_reduce_sum4_fp32_pack4[pb % 2] + : pipeline_rmsnorm_reduce_sum4_fp32[pb % 2]; + cmd.record_pipeline(p_iter, bindings, constants, dispatcher); + pb++; + sqsum_workspace = sqsum_reduced; + } + + { + std::vector bindings(2); + bindings[0] = sqsum_workspace; + bindings[1] = rms_workspace; + + std::vector constants(5); + constants[0].i = sqsum_workspace.w; + constants[1].i = num_groups_per_channel; + constants[2].i = channels; + constants[3].i = sqsum_workspace.cstep; + constants[4].f = (float)group_size; + + VkMat dispatcher; + dispatcher.w = 1; + dispatcher.h = num_groups_per_channel; + dispatcher.c = channels; + const Pipeline* p_mean = elempack == 4 ? pipeline_rmsnorm_reduce_mean_pack4 : pipeline_rmsnorm_reduce_mean; + cmd.record_pipeline(p_mean, bindings, constants, dispatcher); + } } } diff --git a/src/layer/vulkan/rmsnorm_vulkan.h b/src/layer/vulkan/rmsnorm_vulkan.h index 78ce5f9ab752..7bfc66a7bbb1 100644 --- a/src/layer/vulkan/rmsnorm_vulkan.h +++ b/src/layer/vulkan/rmsnorm_vulkan.h @@ -39,6 +39,10 @@ class RMSNorm_vulkan : public RMSNorm Pipeline* pipeline_rmsnorm_reduce_mean_pack4; Pipeline* pipeline_rmsnorm_coeffs_pack4; Pipeline* pipeline_rmsnorm_norm_pack4; + + // subgroup + Pipeline* pipeline_rmsnorm_reduce_subgroup; + Pipeline* pipeline_rmsnorm_reduce_subgroup_pack4; }; } // namespace ncnn diff --git a/src/layer/vulkan/shader/rmsnorm_reduce_subgroup.comp b/src/layer/vulkan/shader/rmsnorm_reduce_subgroup.comp new file mode 100644 index 000000000000..2883fd46a9d6 --- /dev/null +++ b/src/layer/vulkan/shader/rmsnorm_reduce_subgroup.comp @@ -0,0 +1,157 @@ +// Copyright 2026 Futz12 +// SPDX-License-Identifier: BSD-3-Clause + +#version 450 + +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#if NCNN_fp16_storage +#extension GL_EXT_shader_subgroup_extended_types_float16 : require +#endif + +layout(binding = 0) readonly buffer bottom_top_blob { sfp bottom_top_blob_data[]; }; +layout(binding = 1) writeonly buffer rms_blob { float rms_data[]; }; + +layout(push_constant) uniform parameter +{ + int group_size; + int num_groups_per_channel; + int cstep; +} p; + +shared float sdata_sqsum[256]; + +void main() +{ + const int tid = int(gl_LocalInvocationID.x); + const int group_id = int(gl_WorkGroupID.y); + + const int gz = group_id / p.num_groups_per_channel; + const int gy = group_id % p.num_groups_per_channel; + const int offset = gz * p.cstep + gy * p.group_size; + + afp sqsum = afp(0.f); + + for (int t = tid; t < p.group_size; t += 256) + { + afp v = buffer_ld1(bottom_top_blob_data, offset + t); + sqsum += v * v; + } + + afp sg_sqsum = subgroupAdd(sqsum); + + if (subgroupElect()) + { + sdata_sqsum[int(gl_SubgroupID)] = float(sg_sqsum); + } + + barrier(); + +#if ncnn_subgroupSize >= 16 + if (int(gl_SubgroupID) == 0) + { + const int lane = int(gl_SubgroupInvocationID); + const int num_sg = int(gl_NumSubgroups); + + afp v_sqsum = lane < num_sg ? afp(sdata_sqsum[lane]) : afp(0.f); + afp r_sqsum = subgroupAdd(v_sqsum); + + if (subgroupElect()) + { + rms_data[group_id] = float(r_sqsum) / float(p.group_size); + } + } +#elif ncnn_subgroupSize == 8 + if (int(gl_SubgroupID) < 4) + { + const int lane = int(gl_SubgroupInvocationID); + const int base = int(gl_SubgroupID) * 8; + + afp v_sqsum = afp(sdata_sqsum[base + lane]); + afp r_sqsum = subgroupAdd(v_sqsum); + + if (subgroupElect()) + { + sdata_sqsum[int(gl_SubgroupID)] = float(r_sqsum); + } + } + + barrier(); + + if (int(gl_SubgroupID) == 0) + { + const int lane = int(gl_SubgroupInvocationID); + + afp v_sqsum = lane < 4 ? afp(sdata_sqsum[lane]) : afp(0.f); + afp r_sqsum = subgroupAdd(v_sqsum); + + if (subgroupElect()) + { + rms_data[group_id] = float(r_sqsum) / float(p.group_size); + } + } +#elif ncnn_subgroupSize == 4 + if (int(gl_SubgroupID) < 16) + { + const int lane = int(gl_SubgroupInvocationID); + const int base = int(gl_SubgroupID) * 4; + + afp v_sqsum = afp(sdata_sqsum[base + lane]); + afp r_sqsum = subgroupAdd(v_sqsum); + + if (subgroupElect()) + { + sdata_sqsum[int(gl_SubgroupID)] = float(r_sqsum); + } + } + + barrier(); + + if (int(gl_SubgroupID) < 4) + { + const int lane = int(gl_SubgroupInvocationID); + const int base = int(gl_SubgroupID) * 4; + + afp v_sqsum = afp(sdata_sqsum[base + lane]); + afp r_sqsum = subgroupAdd(v_sqsum); + + if (subgroupElect()) + { + sdata_sqsum[int(gl_SubgroupID)] = float(r_sqsum); + } + } + + barrier(); + + if (int(gl_SubgroupID) == 0) + { + const int lane = int(gl_SubgroupInvocationID); + + afp v_sqsum = afp(sdata_sqsum[lane]); + afp r_sqsum = subgroupAdd(v_sqsum); + + if (subgroupElect()) + { + rms_data[group_id] = float(r_sqsum) / float(p.group_size); + } + } +#else + { + const int num_sg = 256 / ncnn_subgroupSize; + + for (int stride = num_sg / 2; stride > 0; stride >>= 1) + { + if (tid < stride) + { + sdata_sqsum[tid] = sdata_sqsum[tid] + sdata_sqsum[tid + stride]; + } + barrier(); + } + + if (tid == 0) + { + rms_data[group_id] = sdata_sqsum[0] / float(p.group_size); + } + } +#endif +} diff --git a/src/layer/vulkan/shader/rmsnorm_reduce_subgroup_pack4.comp b/src/layer/vulkan/shader/rmsnorm_reduce_subgroup_pack4.comp new file mode 100644 index 000000000000..1ec5e3df6017 --- /dev/null +++ b/src/layer/vulkan/shader/rmsnorm_reduce_subgroup_pack4.comp @@ -0,0 +1,157 @@ +// Copyright 2026 Futz12 +// SPDX-License-Identifier: BSD-3-Clause + +#version 450 + +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#if NCNN_fp16_storage +#extension GL_EXT_shader_subgroup_extended_types_float16 : require +#endif + +layout(binding = 0) readonly buffer bottom_top_blob { sfpvec4 bottom_top_blob_data[]; }; +layout(binding = 1) writeonly buffer rms_blob { vec4 rms_data[]; }; + +layout(push_constant) uniform parameter +{ + int group_size; + int num_groups_per_channel; + int cstep; +} p; + +shared vec4 sdata_sqsum[256]; + +void main() +{ + const int tid = int(gl_LocalInvocationID.x); + const int group_id = int(gl_WorkGroupID.y); + + const int gz = group_id / p.num_groups_per_channel; + const int gy = group_id % p.num_groups_per_channel; + const int offset = gz * p.cstep + gy * p.group_size; + + vec4 sqsum = vec4(0.f); + + for (int t = tid; t < p.group_size; t += 256) + { + vec4 v = vec4(buffer_ld4(bottom_top_blob_data, offset + t)); + sqsum += v * v; + } + + vec4 sg_sqsum = subgroupAdd(sqsum); + + if (subgroupElect()) + { + sdata_sqsum[int(gl_SubgroupID)] = vec4(sg_sqsum); + } + + barrier(); + +#if ncnn_subgroupSize >= 16 + if (int(gl_SubgroupID) == 0) + { + const int lane = int(gl_SubgroupInvocationID); + const int num_sg = int(gl_NumSubgroups); + + vec4 v_sqsum = lane < num_sg ? sdata_sqsum[lane] : vec4(0.f); + vec4 r_sqsum = subgroupAdd(v_sqsum); + + if (subgroupElect()) + { + rms_data[group_id] = r_sqsum / float(p.group_size); + } + } +#elif ncnn_subgroupSize == 8 + if (int(gl_SubgroupID) < 4) + { + const int lane = int(gl_SubgroupInvocationID); + const int base = int(gl_SubgroupID) * 8; + + vec4 v_sqsum = sdata_sqsum[base + lane]; + vec4 r_sqsum = subgroupAdd(v_sqsum); + + if (subgroupElect()) + { + sdata_sqsum[int(gl_SubgroupID)] = vec4(r_sqsum); + } + } + + barrier(); + + if (int(gl_SubgroupID) == 0) + { + const int lane = int(gl_SubgroupInvocationID); + + vec4 v_sqsum = lane < 4 ? sdata_sqsum[lane] : vec4(0.f); + vec4 r_sqsum = subgroupAdd(v_sqsum); + + if (subgroupElect()) + { + rms_data[group_id] = r_sqsum / float(p.group_size); + } + } +#elif ncnn_subgroupSize == 4 + if (int(gl_SubgroupID) < 16) + { + const int lane = int(gl_SubgroupInvocationID); + const int base = int(gl_SubgroupID) * 4; + + vec4 v_sqsum = sdata_sqsum[base + lane]; + vec4 r_sqsum = subgroupAdd(v_sqsum); + + if (subgroupElect()) + { + sdata_sqsum[int(gl_SubgroupID)] = vec4(r_sqsum); + } + } + + barrier(); + + if (int(gl_SubgroupID) < 4) + { + const int lane = int(gl_SubgroupInvocationID); + const int base = int(gl_SubgroupID) * 4; + + vec4 v_sqsum = sdata_sqsum[base + lane]; + vec4 r_sqsum = subgroupAdd(v_sqsum); + + if (subgroupElect()) + { + sdata_sqsum[int(gl_SubgroupID)] = vec4(r_sqsum); + } + } + + barrier(); + + if (int(gl_SubgroupID) == 0) + { + const int lane = int(gl_SubgroupInvocationID); + + vec4 v_sqsum = sdata_sqsum[lane]; + vec4 r_sqsum = subgroupAdd(v_sqsum); + + if (subgroupElect()) + { + rms_data[group_id] = r_sqsum / float(p.group_size); + } + } +#else + { + const int num_sg = 256 / ncnn_subgroupSize; + + for (int stride = num_sg / 2; stride > 0; stride >>= 1) + { + if (tid < stride) + { + sdata_sqsum[tid] = sdata_sqsum[tid] + sdata_sqsum[tid + stride]; + } + barrier(); + } + + if (tid == 0) + { + rms_data[group_id] = sdata_sqsum[0] / float(p.group_size); + } + } +#endif +} diff --git a/tests/perf/CMakeLists.txt b/tests/perf/CMakeLists.txt index 10c0535d8087..86586b2a8e30 100644 --- a/tests/perf/CMakeLists.txt +++ b/tests/perf/CMakeLists.txt @@ -38,6 +38,7 @@ ncnn_add_layer_perf(BinaryOp) ncnn_add_layer_perf(Concat) ncnn_add_layer_perf(Sigmoid) ncnn_add_layer_perf(BatchNorm) +ncnn_add_layer_perf(RMSNorm) # SDPA perf tests (decode and prefill phases) if(WITH_LAYER_sdpa) diff --git a/tests/perf/perf_rmsnorm.cpp b/tests/perf/perf_rmsnorm.cpp new file mode 100644 index 000000000000..3ef277b7ead7 --- /dev/null +++ b/tests/perf/perf_rmsnorm.cpp @@ -0,0 +1,33 @@ +// Copyright 2026 Futz12 +// SPDX-License-Identifier: BSD-3-Clause + +#include "perfutil.h" + +static void perf_rmsnorm(int w, int h, int c, int affine_size) +{ + ncnn::ParamDict pd; + pd.set(0, affine_size); + pd.set(1, 1e-5f); + pd.set(2, 1); + + std::vector weights(1); + weights[0] = PerfMat(affine_size, 1.0f); + + perf_layer("RMSNorm", pd, weights, PerfMat(w, h, c), "affine_size=%d", affine_size); +} + +int main() +{ + // typical LLM feature dimensions + perf_rmsnorm(4096, 1, 1, 4096); + perf_rmsnorm(4096, 1, 32, 4096); + perf_rmsnorm(16384, 1, 1, 16384); + perf_rmsnorm(5120, 1, 1, 5120); + perf_rmsnorm(4096, 512, 1, 4096); + + // smaller dims for comparison + perf_rmsnorm(1024, 1, 1, 1024); + perf_rmsnorm(768, 1, 1, 768); + + return 0; +}