Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/core/ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2007,6 +2007,10 @@ struct GGMLRunner {
}

bool copy_cache_tensors_to_cache_buffer(const std::unordered_set<std::string>* cache_keep_names = nullptr) {
if (cache_tensor_map.empty() && cache_keep_names == nullptr) {
return true;
}

ggml_context* old_cache_ctx = cache_ctx;
ggml_backend_buffer_t old_cache_buffer = cache_buffer;
cache_ctx = nullptr;
Expand Down
79 changes: 28 additions & 51 deletions src/model/diffusion/control.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,16 +312,17 @@ struct ControlNet : public GGMLRunner {
ControlNetBlock control_net;
std::string weight_prefix;

ggml_backend_buffer_t control_buffer = nullptr;
ggml_context* control_ctx = nullptr;
std::vector<ggml_tensor*> control_outputs_ggml;
ggml_tensor* guided_hint_output_ggml = nullptr;
std::vector<sd::Tensor<float>> controls;
sd::Tensor<float> guided_hint;
bool guided_hint_cached = false;
std::shared_ptr<ModelManager> owned_model_manager;
ggml_backend_t params_backend = nullptr;

static const char* guided_hint_cache_name() {
return "controlnet.guided_hint";
}

ControlNet(ggml_backend_t backend,
ggml_backend_t params_backend_,
const String2TensorStorage& tensor_storage_map = {},
Expand All @@ -336,44 +337,12 @@ struct ControlNet : public GGMLRunner {
free_control_ctx();
}

void alloc_control_ctx(std::vector<ggml_tensor*> outs) {
ggml_init_params params;
params.mem_size = static_cast<size_t>(outs.size() * ggml_tensor_overhead()) + 1024 * 1024;
params.mem_buffer = nullptr;
params.no_alloc = true;
control_ctx = ggml_init(params);

control_outputs_ggml.resize(outs.size() - 1);

size_t control_buffer_size = 0;

guided_hint_output_ggml = ggml_dup_tensor(control_ctx, outs[0]);
control_buffer_size += ggml_nbytes(guided_hint_output_ggml);

for (int i = 0; i < outs.size() - 1; i++) {
control_outputs_ggml[i] = ggml_dup_tensor(control_ctx, outs[i + 1]);
control_buffer_size += ggml_nbytes(control_outputs_ggml[i]);
}

control_buffer = ggml_backend_alloc_ctx_tensors(control_ctx, runtime_backend);

LOG_DEBUG("control buffer size %.2fMB", control_buffer_size * 1.f / 1024.f / 1024.f);
}

void free_control_ctx() {
if (control_buffer != nullptr) {
ggml_backend_buffer_free(control_buffer);
control_buffer = nullptr;
}
if (control_ctx != nullptr) {
ggml_free(control_ctx);
control_ctx = nullptr;
}
guided_hint_output_ggml = nullptr;
guided_hint_cached = false;
guided_hint = {};
control_outputs_ggml.clear();
controls.clear();
free_cache_ctx_and_buffer();
}

std::string get_desc() override {
Expand All @@ -397,11 +366,17 @@ struct ControlNet : public GGMLRunner {
ggml_tensor* context = make_optional_input(context_tensor);
ggml_tensor* y = make_optional_input(y_tensor);

guided_hint_output_ggml = nullptr;
control_outputs_ggml.clear();

ggml_tensor* guided_hint_input = nullptr;
if (guided_hint_cached && !guided_hint.empty()) {
guided_hint_input = make_input(guided_hint);
hint = nullptr;
} else {
if (guided_hint_cached) {
guided_hint_input = get_cache_tensor_by_name(guided_hint_cache_name());
if (guided_hint_input == nullptr) {
guided_hint_cached = false;
}
}
if (guided_hint_input == nullptr) {
hint = make_input(hint_tensor);
}

Expand All @@ -415,13 +390,19 @@ struct ControlNet : public GGMLRunner {
context,
y);

if (control_ctx == nullptr) {
alloc_control_ctx(outs);
if (guided_hint_input == nullptr && !outs.empty()) {
guided_hint_output_ggml = outs[0];
ggml_set_output(guided_hint_output_ggml);
cache(guided_hint_cache_name(), guided_hint_output_ggml);
ggml_build_forward_expand(gf, guided_hint_output_ggml);
}

ggml_build_forward_expand(gf, ggml_cpy(compute_ctx, outs[0], guided_hint_output_ggml));
for (int i = 0; i < outs.size() - 1; i++) {
ggml_build_forward_expand(gf, ggml_cpy(compute_ctx, outs[i + 1], control_outputs_ggml[i]));
control_outputs_ggml.reserve(outs.size() > 0 ? outs.size() - 1 : 0);
for (size_t i = 1; i < outs.size(); i++) {
ggml_tensor* control_output = outs[i];
ggml_set_output(control_output);
ggml_build_forward_expand(gf, control_output);
control_outputs_ggml.push_back(control_output);
}

return gf;
Expand All @@ -441,23 +422,19 @@ struct ControlNet : public GGMLRunner {
return build_graph(x, hint, timesteps, context, y);
};

auto compute_result = GGMLRunner::compute<float>(get_graph, n_threads, false, false, false);
auto compute_result = GGMLRunner::compute<float>(get_graph, n_threads, false, false, false, true);
if (!compute_result.has_value()) {
return std::nullopt;
}

if (guided_hint_output_ggml != nullptr) {
guided_hint = restore_trailing_singleton_dims(sd::make_sd_tensor_from_ggml<float>(guided_hint_output_ggml),
4);
}
guided_hint_cached = get_cache_tensor_by_name(guided_hint_cache_name()) != nullptr;
controls.clear();
controls.reserve(control_outputs_ggml.size());
for (ggml_tensor* control : control_outputs_ggml) {
auto control_host = restore_trailing_singleton_dims(sd::make_sd_tensor_from_ggml<float>(control), 4);
GGML_ASSERT(!control_host.empty());
controls.push_back(std::move(control_host));
}
guided_hint_cached = true;
return controls;
}

Expand Down
Loading