From 81fb93fa3ab2a5b2a81cfdf662d9622d0e1e3576 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 14 Jun 2026 16:41:32 +0800 Subject: [PATCH] refactor: simplify ControlNet output caching --- src/core/ggml_extend.hpp | 4 ++ src/model/diffusion/control.hpp | 79 ++++++++++++--------------------- 2 files changed, 32 insertions(+), 51 deletions(-) diff --git a/src/core/ggml_extend.hpp b/src/core/ggml_extend.hpp index 76109d04c..a3dda16b2 100644 --- a/src/core/ggml_extend.hpp +++ b/src/core/ggml_extend.hpp @@ -2007,6 +2007,10 @@ struct GGMLRunner { } bool copy_cache_tensors_to_cache_buffer(const std::unordered_set* cache_keep_names = nullptr) { + if (cache_tensor_map.empty() && cache_keep_names == nullptr) { + return true; + } + ggml_context* old_cache_ctx = cache_ctx; ggml_backend_buffer_t old_cache_buffer = cache_buffer; cache_ctx = nullptr; diff --git a/src/model/diffusion/control.hpp b/src/model/diffusion/control.hpp index d857fa095..57e3616f2 100644 --- a/src/model/diffusion/control.hpp +++ b/src/model/diffusion/control.hpp @@ -312,16 +312,17 @@ struct ControlNet : public GGMLRunner { ControlNetBlock control_net; std::string weight_prefix; - ggml_backend_buffer_t control_buffer = nullptr; - ggml_context* control_ctx = nullptr; std::vector control_outputs_ggml; ggml_tensor* guided_hint_output_ggml = nullptr; std::vector> controls; - sd::Tensor guided_hint; bool guided_hint_cached = false; std::shared_ptr owned_model_manager; ggml_backend_t params_backend = nullptr; + static const char* guided_hint_cache_name() { + return "controlnet.guided_hint"; + } + ControlNet(ggml_backend_t backend, ggml_backend_t params_backend_, const String2TensorStorage& tensor_storage_map = {}, @@ -336,44 +337,12 @@ struct ControlNet : public GGMLRunner { free_control_ctx(); } - void alloc_control_ctx(std::vector outs) { - ggml_init_params params; - params.mem_size = static_cast(outs.size() * ggml_tensor_overhead()) + 1024 * 1024; - params.mem_buffer = nullptr; - params.no_alloc = true; - control_ctx = ggml_init(params); - - control_outputs_ggml.resize(outs.size() - 1); - - size_t control_buffer_size = 0; - - guided_hint_output_ggml = ggml_dup_tensor(control_ctx, outs[0]); - control_buffer_size += ggml_nbytes(guided_hint_output_ggml); - - for (int i = 0; i < outs.size() - 1; i++) { - control_outputs_ggml[i] = ggml_dup_tensor(control_ctx, outs[i + 1]); - control_buffer_size += ggml_nbytes(control_outputs_ggml[i]); - } - - control_buffer = ggml_backend_alloc_ctx_tensors(control_ctx, runtime_backend); - - LOG_DEBUG("control buffer size %.2fMB", control_buffer_size * 1.f / 1024.f / 1024.f); - } - void free_control_ctx() { - if (control_buffer != nullptr) { - ggml_backend_buffer_free(control_buffer); - control_buffer = nullptr; - } - if (control_ctx != nullptr) { - ggml_free(control_ctx); - control_ctx = nullptr; - } guided_hint_output_ggml = nullptr; guided_hint_cached = false; - guided_hint = {}; control_outputs_ggml.clear(); controls.clear(); + free_cache_ctx_and_buffer(); } std::string get_desc() override { @@ -397,11 +366,17 @@ struct ControlNet : public GGMLRunner { ggml_tensor* context = make_optional_input(context_tensor); ggml_tensor* y = make_optional_input(y_tensor); + guided_hint_output_ggml = nullptr; + control_outputs_ggml.clear(); + ggml_tensor* guided_hint_input = nullptr; - if (guided_hint_cached && !guided_hint.empty()) { - guided_hint_input = make_input(guided_hint); - hint = nullptr; - } else { + if (guided_hint_cached) { + guided_hint_input = get_cache_tensor_by_name(guided_hint_cache_name()); + if (guided_hint_input == nullptr) { + guided_hint_cached = false; + } + } + if (guided_hint_input == nullptr) { hint = make_input(hint_tensor); } @@ -415,13 +390,19 @@ struct ControlNet : public GGMLRunner { context, y); - if (control_ctx == nullptr) { - alloc_control_ctx(outs); + if (guided_hint_input == nullptr && !outs.empty()) { + guided_hint_output_ggml = outs[0]; + ggml_set_output(guided_hint_output_ggml); + cache(guided_hint_cache_name(), guided_hint_output_ggml); + ggml_build_forward_expand(gf, guided_hint_output_ggml); } - ggml_build_forward_expand(gf, ggml_cpy(compute_ctx, outs[0], guided_hint_output_ggml)); - for (int i = 0; i < outs.size() - 1; i++) { - ggml_build_forward_expand(gf, ggml_cpy(compute_ctx, outs[i + 1], control_outputs_ggml[i])); + control_outputs_ggml.reserve(outs.size() > 0 ? outs.size() - 1 : 0); + for (size_t i = 1; i < outs.size(); i++) { + ggml_tensor* control_output = outs[i]; + ggml_set_output(control_output); + ggml_build_forward_expand(gf, control_output); + control_outputs_ggml.push_back(control_output); } return gf; @@ -441,15 +422,12 @@ struct ControlNet : public GGMLRunner { return build_graph(x, hint, timesteps, context, y); }; - auto compute_result = GGMLRunner::compute(get_graph, n_threads, false, false, false); + auto compute_result = GGMLRunner::compute(get_graph, n_threads, false, false, false, true); if (!compute_result.has_value()) { return std::nullopt; } - if (guided_hint_output_ggml != nullptr) { - guided_hint = restore_trailing_singleton_dims(sd::make_sd_tensor_from_ggml(guided_hint_output_ggml), - 4); - } + guided_hint_cached = get_cache_tensor_by_name(guided_hint_cache_name()) != nullptr; controls.clear(); controls.reserve(control_outputs_ggml.size()); for (ggml_tensor* control : control_outputs_ggml) { @@ -457,7 +435,6 @@ struct ControlNet : public GGMLRunner { GGML_ASSERT(!control_host.empty()); controls.push_back(std::move(control_host)); } - guided_hint_cached = true; return controls; }