Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 127 additions & 83 deletions src/layer/vulkan/rmsnorm_vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ RMSNorm_vulkan::RMSNorm_vulkan()
pipeline_rmsnorm_reduce_mean_pack4 = 0;
pipeline_rmsnorm_coeffs_pack4 = 0;
pipeline_rmsnorm_norm_pack4 = 0;

// subgroup
pipeline_rmsnorm_reduce_subgroup = 0;
pipeline_rmsnorm_reduce_subgroup_pack4 = 0;
}

int RMSNorm_vulkan::create_pipeline(const Option& opt)
Expand Down Expand Up @@ -102,6 +106,18 @@ int RMSNorm_vulkan::create_pipeline(const Option& opt)
pipeline_rmsnorm_norm_pack4->set_optimal_local_size_xyz(8, 8, 1);
pipeline_rmsnorm_norm_pack4->create(LayerShaderType::rmsnorm_norm_pack4, opt, spec);
}

if (vkdev->info.support_subgroup_ops() & VK_SUBGROUP_FEATURE_ARITHMETIC_BIT)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Honor disabled subgroup ops before creating the subgroup pipeline

When callers set opt.use_subgroup_ops = false on a device that still advertises subgroup arithmetic, this branch still creates and later uses the new subgroup RMSNorm pipeline. The shader is guarded by the device macro ncnn_subgroup_arithmetic, so it will compile/execute subgroup instructions despite the option being disabled (and the pipeline compiler targets the non-subgroup SPIR-V environment when the option is false). Please gate this optimization on both the device capability and opt.use_subgroup_ops, falling back to the existing reduce path when subgroup ops are disabled.

Useful? React with 👍 / 👎.

{
pipeline_rmsnorm_reduce_subgroup = new Pipeline(vkdev);
pipeline_rmsnorm_reduce_subgroup->set_local_size_xyz(256, 1, 1);
pipeline_rmsnorm_reduce_subgroup->create(LayerShaderType::rmsnorm_reduce_subgroup, opt, std::vector<vk_specialization_type>());

pipeline_rmsnorm_reduce_subgroup_pack4 = new Pipeline(vkdev);
pipeline_rmsnorm_reduce_subgroup_pack4->set_local_size_xyz(256, 1, 1);
pipeline_rmsnorm_reduce_subgroup_pack4->create(LayerShaderType::rmsnorm_reduce_subgroup_pack4, opt, std::vector<vk_specialization_type>());
}

return 0;
}

Expand Down Expand Up @@ -139,6 +155,11 @@ int RMSNorm_vulkan::destroy_pipeline(const Option&)
pipeline_rmsnorm_coeffs_pack4 = 0;
pipeline_rmsnorm_norm_pack4 = 0;

delete pipeline_rmsnorm_reduce_subgroup;
pipeline_rmsnorm_reduce_subgroup = 0;
delete pipeline_rmsnorm_reduce_subgroup_pack4;
pipeline_rmsnorm_reduce_subgroup_pack4 = 0;

return 0;
}

Expand Down Expand Up @@ -196,106 +217,129 @@ int RMSNorm_vulkan::forward_inplace(VkMat& bottom_top_blob, VkCompute& cmd, cons
}
int num_groups_total = num_groups_per_channel * channels;

// 1) x -> x^2
VkMat square_workspace(w, h, channels, elemsize, elempack, opt.workspace_vkallocator);
VkMat rms_workspace(num_groups_total, 4u * elempack, elempack, opt.workspace_vkallocator);

const Pipeline* pipeline_reduce_subgroup = elempack == 4 ? pipeline_rmsnorm_reduce_subgroup_pack4 : pipeline_rmsnorm_reduce_subgroup;
if (pipeline_reduce_subgroup)
{
std::vector<VkMat> bindings(2);
bindings[0] = bottom_top_blob;
bindings[1] = square_workspace;
bindings[1] = rms_workspace;

std::vector<vk_constant_type> constants(4);
constants[0].i = w;
constants[1].i = h;
constants[2].i = channels;
constants[3].i = cstep;
std::vector<vk_constant_type> constants(3);
constants[0].i = group_size;
constants[1].i = num_groups_per_channel;
constants[2].i = (int)cstep;

const Pipeline* pipe_sq = elempack == 4 ? pipeline_rmsnorm_square_pack4 : pipeline_rmsnorm_square;
VkMat dispatcher;
dispatcher.w = 1;
dispatcher.h = num_groups_total;
dispatcher.c = 1;

cmd.record_pipeline(pipe_sq, bindings, constants, square_workspace);
cmd.record_pipeline(pipeline_reduce_subgroup, bindings, constants, dispatcher);
}

// 2) reduce sum4 (square) -> ... -> mean
VkMat rms_workspace(num_groups_total, 4u * elempack, elempack, opt.workspace_vkallocator);
else
{
int reduced_w = (group_size + 3) / 4;
VkMat sqsum_workspace;
sqsum_workspace.create(reduced_w, num_groups_per_channel, channels, 4u * elempack, elempack, opt.workspace_vkallocator);

// 1) x -> x^2
VkMat square_workspace(w, h, channels, elemsize, elempack, opt.workspace_vkallocator);
{
std::vector<VkMat> bindings(2);
bindings[0] = square_workspace;
bindings[1] = sqsum_workspace;
bindings[0] = bottom_top_blob;
bindings[1] = square_workspace;

std::vector<vk_constant_type> constants(8);
constants[0].i = group_size;
constants[1].i = num_groups_per_channel;
std::vector<vk_constant_type> constants(4);
constants[0].i = w;
constants[1].i = h;
constants[2].i = channels;
constants[3].i = square_workspace.cstep;
constants[4].i = reduced_w;
constants[5].i = num_groups_per_channel;
constants[6].i = channels;
constants[7].i = sqsum_workspace.cstep;

VkMat dispatcher;
dispatcher.w = reduced_w;
dispatcher.h = num_groups_per_channel;
dispatcher.c = channels;
const Pipeline* p_reduce = elempack == 4 ? pipeline_rmsnorm_reduce_sum4_fp16_to_fp32_pack4
: pipeline_rmsnorm_reduce_sum4_fp16_to_fp32;
cmd.record_pipeline(p_reduce, bindings, constants, dispatcher);
}
int pb = 1;
while (sqsum_workspace.w > 1)
{
int current_w = sqsum_workspace.w;
reduced_w = (current_w + 3) / 4;

VkMat sqsum_reduced;
sqsum_reduced.create(reduced_w, num_groups_per_channel, channels, 4u * elempack, elempack, opt.workspace_vkallocator);
constants[3].i = cstep;

std::vector<VkMat> bindings(2);
bindings[0] = sqsum_workspace;
bindings[1] = sqsum_reduced;
const Pipeline* pipe_sq = elempack == 4 ? pipeline_rmsnorm_square_pack4 : pipeline_rmsnorm_square;

std::vector<vk_constant_type> constants(8);
constants[0].i = current_w;
constants[1].i = num_groups_per_channel;
constants[2].i = channels;
constants[3].i = sqsum_workspace.cstep;
constants[4].i = reduced_w;
constants[5].i = num_groups_per_channel;
constants[6].i = channels;
constants[7].i = sqsum_reduced.cstep;

VkMat dispatcher;
dispatcher.w = reduced_w;
dispatcher.h = num_groups_per_channel;
dispatcher.c = channels;
const Pipeline* p_iter = elempack == 4 ? pipeline_rmsnorm_reduce_sum4_fp32_pack4[pb % 2]
: pipeline_rmsnorm_reduce_sum4_fp32[pb % 2];
cmd.record_pipeline(p_iter, bindings, constants, dispatcher);
pb++;
sqsum_workspace = sqsum_reduced;
cmd.record_pipeline(pipe_sq, bindings, constants, square_workspace);
}

// 2) reduce sum4 (square) -> ... -> mean
{
std::vector<VkMat> bindings(2);
bindings[0] = sqsum_workspace;
bindings[1] = rms_workspace;

std::vector<vk_constant_type> constants(5);
constants[0].i = sqsum_workspace.w;
constants[1].i = num_groups_per_channel;
constants[2].i = channels;
constants[3].i = sqsum_workspace.cstep;
constants[4].f = (float)group_size;

VkMat dispatcher;
dispatcher.w = 1;
dispatcher.h = num_groups_per_channel;
dispatcher.c = channels;
const Pipeline* p_mean = elempack == 4 ? pipeline_rmsnorm_reduce_mean_pack4 : pipeline_rmsnorm_reduce_mean;
cmd.record_pipeline(p_mean, bindings, constants, dispatcher);
int reduced_w = (group_size + 3) / 4;
VkMat sqsum_workspace;
sqsum_workspace.create(reduced_w, num_groups_per_channel, channels, 4u * elempack, elempack, opt.workspace_vkallocator);

{
std::vector<VkMat> bindings(2);
bindings[0] = square_workspace;
bindings[1] = sqsum_workspace;

std::vector<vk_constant_type> constants(8);
constants[0].i = group_size;
constants[1].i = num_groups_per_channel;
constants[2].i = channels;
constants[3].i = square_workspace.cstep;
constants[4].i = reduced_w;
constants[5].i = num_groups_per_channel;
constants[6].i = channels;
constants[7].i = sqsum_workspace.cstep;

VkMat dispatcher;
dispatcher.w = reduced_w;
dispatcher.h = num_groups_per_channel;
dispatcher.c = channels;
const Pipeline* p_reduce = elempack == 4 ? pipeline_rmsnorm_reduce_sum4_fp16_to_fp32_pack4
: pipeline_rmsnorm_reduce_sum4_fp16_to_fp32;
cmd.record_pipeline(p_reduce, bindings, constants, dispatcher);
}
int pb = 1;
while (sqsum_workspace.w > 1)
{
int current_w = sqsum_workspace.w;
reduced_w = (current_w + 3) / 4;

VkMat sqsum_reduced;
sqsum_reduced.create(reduced_w, num_groups_per_channel, channels, 4u * elempack, elempack, opt.workspace_vkallocator);

std::vector<VkMat> bindings(2);
bindings[0] = sqsum_workspace;
bindings[1] = sqsum_reduced;

std::vector<vk_constant_type> constants(8);
constants[0].i = current_w;
constants[1].i = num_groups_per_channel;
constants[2].i = channels;
constants[3].i = sqsum_workspace.cstep;
constants[4].i = reduced_w;
constants[5].i = num_groups_per_channel;
constants[6].i = channels;
constants[7].i = sqsum_reduced.cstep;

VkMat dispatcher;
dispatcher.w = reduced_w;
dispatcher.h = num_groups_per_channel;
dispatcher.c = channels;
const Pipeline* p_iter = elempack == 4 ? pipeline_rmsnorm_reduce_sum4_fp32_pack4[pb % 2]
: pipeline_rmsnorm_reduce_sum4_fp32[pb % 2];
cmd.record_pipeline(p_iter, bindings, constants, dispatcher);
pb++;
sqsum_workspace = sqsum_reduced;
}

{
std::vector<VkMat> bindings(2);
bindings[0] = sqsum_workspace;
bindings[1] = rms_workspace;

std::vector<vk_constant_type> constants(5);
constants[0].i = sqsum_workspace.w;
constants[1].i = num_groups_per_channel;
constants[2].i = channels;
constants[3].i = sqsum_workspace.cstep;
constants[4].f = (float)group_size;

VkMat dispatcher;
dispatcher.w = 1;
dispatcher.h = num_groups_per_channel;
dispatcher.c = channels;
const Pipeline* p_mean = elempack == 4 ? pipeline_rmsnorm_reduce_mean_pack4 : pipeline_rmsnorm_reduce_mean;
cmd.record_pipeline(p_mean, bindings, constants, dispatcher);
}
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/layer/vulkan/rmsnorm_vulkan.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class RMSNorm_vulkan : public RMSNorm
Pipeline* pipeline_rmsnorm_reduce_mean_pack4;
Pipeline* pipeline_rmsnorm_coeffs_pack4;
Pipeline* pipeline_rmsnorm_norm_pack4;

// subgroup
Pipeline* pipeline_rmsnorm_reduce_subgroup;
Pipeline* pipeline_rmsnorm_reduce_subgroup_pack4;
};

} // namespace ncnn
Expand Down
Loading