Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,10 @@ ArgOptions SDContextParams::get_options() {
"--stream-layers",
"enable residency+prefetch streaming on top of --max-vram (no effect without --max-vram; defaults to false)",
true, &stream_layers},
{"",
"--eager-load-params",
"load all model params into the params backend up front instead of lazily on first use (faster steady-state; higher load-time cost)",
true, &eager_load_params},
{"",
"--force-sdxl-vae-conv-scale",
"force use of conv scale on sdxl vae",
Expand Down Expand Up @@ -760,6 +764,7 @@ std::string SDContextParams::to_string() const {
<< " offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n"
<< " max_vram: " << max_vram << ",\n"
<< " stream_layers: " << (stream_layers ? "true" : "false") << ",\n"
<< " eager_load_params: " << (eager_load_params ? "true" : "false") << ",\n"
<< " backend: \"" << backend << "\",\n"
<< " params_backend: \"" << params_backend << "\",\n"
<< " enable_mmap: " << (enable_mmap ? "true" : "false") << ",\n"
Expand Down Expand Up @@ -838,6 +843,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool taesd_preview) {
sd_ctx_params.vae_format = str_to_vae_format(vae_format);
sd_ctx_params.max_vram = max_vram;
sd_ctx_params.stream_layers = stream_layers;
sd_ctx_params.eager_load_params = eager_load_params;
sd_ctx_params.backend = effective_backend.c_str();
sd_ctx_params.params_backend = effective_params_backend.c_str();
sd_ctx_params.rpc_servers = rpc_servers.c_str();
Expand Down
1 change: 1 addition & 0 deletions examples/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ struct SDContextParams {
bool offload_params_to_cpu = false;
float max_vram = 0.f;
bool stream_layers = false;
bool eager_load_params = false;
std::string backend;
std::string params_backend;
std::string rpc_servers;
Expand Down
1 change: 1 addition & 0 deletions include/stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ typedef struct {
enum sd_vae_format_t vae_format;
float max_vram; // GiB budget for graph-cut segmented param offload (0 = disabled, -1 = auto free VRAM minus 1 GiB)
bool stream_layers; // Enable residency+prefetch streaming on top of --max-vram (no effect without --max-vram)
bool eager_load_params; // Load all model params into the params backend at model-load time instead of lazily on first use
const char* backend;
const char* params_backend;
const char* rpc_servers;
Expand Down
11 changes: 11 additions & 0 deletions src/model_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,17 @@ bool ModelManager::register_param_tensors(const std::string& desc,
return true;
}

bool ModelManager::load_all_params_eagerly() {
std::vector<TensorState*> all_states;
all_states.reserve(tensor_states_.size());
for (const auto& s : tensor_states_) {
if (s != nullptr) {
all_states.push_back(s.get());
}
}
return load_tensors_to_params_backend(all_states);
}

bool ModelManager::validate_registered_tensors() {
bool ok = true;
for (const auto& state : tensor_states_) {
Expand Down
1 change: 1 addition & 0 deletions src/model_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ class ModelManager : public RunnerWeightManager {
}

bool validate_registered_tensors();
bool load_all_params_eagerly();

bool prepare_params(const std::vector<ggml_tensor*>& tensors) override;
void release_compute_backend_params(const std::vector<ggml_tensor*>& tensors) override;
Expand Down
16 changes: 14 additions & 2 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ class StableDiffusionGGML {
bool enable_mmap = false;
float max_vram = 0.f;
bool stream_layers = false;
bool eager_load_params = false;
std::string backend_spec;
std::string params_backend_spec;

Expand Down Expand Up @@ -316,6 +317,7 @@ class StableDiffusionGGML {
enable_mmap = sd_ctx_params->enable_mmap;
max_vram = sd_ctx_params->max_vram;
stream_layers = sd_ctx_params->stream_layers;
eager_load_params = sd_ctx_params->eager_load_params;
backend_spec = SAFE_STR(sd_ctx_params->backend);
params_backend_spec = SAFE_STR(sd_ctx_params->params_backend);

Expand Down Expand Up @@ -1093,8 +1095,15 @@ class StableDiffusionGGML {
LOG_ERROR("model metadata validation failed");
return false;
}

LOG_DEBUG("model metadata validated; weights will be prepared lazily");
if (eager_load_params) {
if (!model_manager->load_all_params_eagerly()) {
LOG_ERROR("eager params load failed");
return false;
}
LOG_DEBUG("model metadata validated; weights pre-loaded to params backend");
} else {
LOG_DEBUG("model metadata validated; weights will be prepared lazily");
}

{
size_t total_params_ram_size = 0;
Expand Down Expand Up @@ -2620,6 +2629,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
sd_ctx_params->max_vram = 0.f;
sd_ctx_params->stream_layers = false;
sd_ctx_params->eager_load_params = false;
sd_ctx_params->enable_mmap = false;
sd_ctx_params->diffusion_flash_attn = false;
sd_ctx_params->circular_x = false;
Expand Down Expand Up @@ -2663,6 +2673,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"prediction: %s\n"
"max_vram: %.3f\n"
"stream_layers: %s\n"
"eager_load_params: %s\n"
"backend: %s\n"
"params_backend: %s\n"
"flash_attn: %s\n"
Expand Down Expand Up @@ -2697,6 +2708,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
sd_prediction_name(sd_ctx_params->prediction),
sd_ctx_params->max_vram,
BOOL_STR(sd_ctx_params->stream_layers),
BOOL_STR(sd_ctx_params->eager_load_params),
SAFE_STR(sd_ctx_params->backend),
SAFE_STR(sd_ctx_params->params_backend),
BOOL_STR(sd_ctx_params->flash_attn),
Expand Down
Loading