diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1068 @@
|
|
1
|
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import copy
|
16
|
+
import time
|
17
|
+
from collections import OrderedDict
|
18
|
+
from itertools import combinations
|
19
|
+
from typing import Any, Dict, List, Optional, Union
|
20
|
+
|
21
|
+
import torch
|
22
|
+
|
23
|
+
from ..hooks import ModelHook
|
24
|
+
from ..utils import (
|
25
|
+
is_accelerate_available,
|
26
|
+
logging,
|
27
|
+
)
|
28
|
+
|
29
|
+
|
30
|
+
if is_accelerate_available():
|
31
|
+
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
32
|
+
from accelerate.state import PartialState
|
33
|
+
from accelerate.utils import send_to_device
|
34
|
+
from accelerate.utils.memory import clear_device_cache
|
35
|
+
from accelerate.utils.modeling import convert_file_size_to_int
|
36
|
+
|
37
|
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
38
|
+
|
39
|
+
|
40
|
+
class CustomOffloadHook(ModelHook):
    """
    Hook that keeps a model on the CPU until its forward pass runs, then moves the model and its forward inputs to
    `execution_device`. Before the model is moved, hooks registered in `other_hooks` whose models currently sit on the
    execution device are offloaded back to the CPU; an optional `offload_strategy` narrows down which of them to offload.

    Args:
        execution_device (`str`, `int` or `torch.device`, *optional*):
            Device the model is executed on. When omitted, `accelerate.state.PartialState().default_device` is used.
        other_hooks (`List[UserCustomOffloadHook]`, *optional*):
            Hooks of other models that may be offloaded to make room for this one.
        offload_strategy (`AutoOffloadStrategy`, *optional*):
            Callable that picks which of the candidate hooks to offload. When `None`, every candidate whose model is
            on the execution device is offloaded.
    """

    # NOTE(review): presumably read by accelerate's hooked forward; False means forward is not forced under no_grad
    no_grad = False

    def __init__(
        self,
        execution_device: Optional[Union[str, int, torch.device]] = None,
        other_hooks: Optional[List["UserCustomOffloadHook"]] = None,
        offload_strategy: Optional["AutoOffloadStrategy"] = None,
    ):
        if execution_device is None:
            execution_device = PartialState().default_device
        self.execution_device = execution_device
        self.other_hooks = other_hooks
        self.offload_strategy = offload_strategy
        # filled in by the owning `UserCustomOffloadHook.attach()` and reset by its `remove()`
        self.model_id = None

    def set_strategy(self, offload_strategy: "AutoOffloadStrategy"):
        """Swap the strategy used to choose which other models get offloaded."""
        self.offload_strategy = offload_strategy

    def add_other_hook(self, hook: "UserCustomOffloadHook"):
        """
        Register `hook` as an offloading candidate for when this hook's model is moved onto the execution device.
        """
        if self.other_hooks is None:
            self.other_hooks = [hook]
        else:
            self.other_hooks.append(hook)

    def init_hook(self, module):
        """Send `module` to the CPU and return it."""
        return module.to("cpu")

    def _offload_other_hooks(self, module):
        """Offload the other registered models that are on the execution device (filtered by the strategy, if any)."""
        hooks_to_offload = [
            hook for hook in self.other_hooks if hook.model.device == self.execution_device
        ]

        # offload all other hooks
        start_time = time.perf_counter()
        if self.offload_strategy is not None:
            hooks_to_offload = self.offload_strategy(
                hooks=hooks_to_offload,
                model_id=self.model_id,
                model=module,
                execution_device=self.execution_device,
            )
        end_time = time.perf_counter()
        logger.info(
            f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds"
        )

        for hook in hooks_to_offload:
            logger.info(f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu")
            hook.offload()

        if hooks_to_offload:
            # clear the device cache now that other models have been moved off the device
            clear_device_cache()

    def pre_forward(self, module, *args, **kwargs):
        """
        Make sure `module` is on the execution device (offloading other models first when it has to be moved) and
        return the positional and keyword forward inputs moved to that device.
        """
        if module.device != self.execution_device:
            if self.other_hooks is not None:
                self._offload_other_hooks(module)
            module.to(self.execution_device)

        target = self.execution_device
        moved_args = send_to_device(args, target)
        moved_kwargs = send_to_device(kwargs, target)
        return moved_args, moved_kwargs
|
106
|
+
|
107
|
+
|
108
|
+
class UserCustomOffloadHook:
    """
    User-facing handle bundling a model, its id and the `CustomOffloadHook` managing it. It exposes small helpers to
    offload the model, attach/detach the hook via accelerate, and link other hooks as offloading candidates.
    """

    def __init__(self, model_id, model, hook):
        self.model_id, self.model, self.hook = model_id, model, hook

    def offload(self):
        """Move the wrapped model to the CPU by running the hook's `init_hook`."""
        self.hook.init_hook(self.model)

    def attach(self):
        """Install the hook on the model with accelerate and tell the hook which model id it serves."""
        add_hook_to_module(self.model, self.hook)
        self.hook.model_id = self.model_id

    def remove(self):
        """Uninstall the hook from the model with accelerate and clear the hook's model id."""
        remove_hook_from_module(self.model)
        self.hook.model_id = None

    def add_other_hook(self, hook: "UserCustomOffloadHook"):
        """Delegate to the wrapped hook's `add_other_hook`."""
        self.hook.add_other_hook(hook)
|
132
|
+
|
133
|
+
|
134
|
+
def custom_offload_with_hook(
    model_id: str,
    model: torch.nn.Module,
    execution_device: Union[str, int, torch.device] = None,
    offload_strategy: Optional["AutoOffloadStrategy"] = None,
):
    """
    Wrap `model` in a `CustomOffloadHook`, attach it, and hand back the controlling `UserCustomOffloadHook`.

    Args:
        model_id (`str`): Identifier recorded on the hook once it is attached.
        model (`torch.nn.Module`): Model to manage.
        execution_device (`str`, `int` or `torch.device`, *optional*): Forwarded to `CustomOffloadHook`.
        offload_strategy (`AutoOffloadStrategy`, *optional*): Forwarded to `CustomOffloadHook`.

    Returns:
        `UserCustomOffloadHook`: the handle owning the freshly attached hook.
    """
    user_hook = UserCustomOffloadHook(
        model_id=model_id,
        model=model,
        hook=CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy),
    )
    user_hook.attach()
    return user_hook
|
144
|
+
|
145
|
+
|
146
|
+
# This is the class that users can customize to implement their own offload strategy
|
147
|
+
class AutoOffloadStrategy:
    """
    Offload strategy meant for `CustomOffloadHook`. It chooses which other models to move to the CPU so that the model
    about to run fits into the free device memory, while keeping `memory_reserve_margin` bytes spare. Free memory is
    queried with `torch.cuda.mem_get_info`, so a CUDA `torch.device` is assumed as the execution device.

    Args:
        memory_reserve_margin (`str` or `int`, defaults to `"3GB"`):
            Device memory to leave unused, parsed with accelerate's `convert_file_size_to_int`.
    """

    # YiYi TODO: instead of memory_reserve_margin, we should let user set the maximum_total_models_size to keep on device
    # the actual memory usage would be higher. But it's simpler this way, and can be tested
    def __init__(self, memory_reserve_margin="3GB"):
        self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin)

    @staticmethod
    def _search_best_candidate(module_sizes, min_memory_offload):
        """
        Return the tuple of model ids whose summed size is the smallest total that is still >= `min_memory_offload`,
        or `None` when no combination is large enough. Combinations are tried with the fewest models first (in
        `module_sizes` order); on ties the first one found wins. Exhaustive search: cost grows as 2**len(module_sizes).
        """
        model_ids = list(module_sizes)

        def total(candidate):
            return sum(module_sizes[model_id] for model_id in candidate)

        every_candidate = (
            candidate for r in range(1, len(model_ids) + 1) for candidate in combinations(model_ids, r)
        )
        large_enough = (candidate for candidate in every_candidate if total(candidate) >= min_memory_offload)
        return min(large_enough, key=total, default=None)

    def __call__(self, hooks, model_id, model, execution_device):
        """
        Return the subset of `hooks` whose models should be offloaded before `model` runs on `execution_device`.

        Returns an empty list when there are no hooks or when `model` already fits into the free memory minus the
        reserve margin. When no combination frees enough memory, every hook in `hooks` is returned.
        """
        if not hooks:
            return []

        current_module_size = model.get_memory_footprint()
        mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0] - self.memory_reserve_margin
        if current_module_size < mem_on_device:
            return []

        min_memory_offload = current_module_size - mem_on_device
        logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory")

        # `CustomOffloadHook` only passes hooks whose models are already on the execution device;
        # order them from largest to smallest footprint
        footprints = {hook.model_id: hook.model.get_memory_footprint() for hook in hooks}
        module_sizes = dict(sorted(footprints.items(), key=lambda item: item[1], reverse=True))

        # YiYi/Dhruv TODO: sort smallest to largest, and offload in that order we would tend to keep the larger models on GPU more often
        best_offload_model_ids = self._search_best_candidate(module_sizes, min_memory_offload)

        if best_offload_model_ids is None:
            # if no combination is found, meaning that we cannot meet the memory requirement, offload all models
            logger.warning("no combination of models to offload to cpu is found, offloading all models")
            return hooks

        return [hook for hook in hooks if hook.model_id in best_offload_model_ids]
|
215
|
+
|
216
|
+
|
217
|
+
# Utilities for displaying component info in a readable format
|
218
|
+
# TODO: move to a different file
|
219
|
+
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
    """Summarize a dictionary by collapsing keys that share a value into their longest common dotted prefix.

    Keys with equal values are grouped together; list values are compared by content. Each group is represented by
    the longest run of leading dot-separated parts shared by all of its keys. A group with a single key keeps the full
    key. A group whose keys share no leading part is stored under the empty string `""`; if several groups lack a
    common prefix, the last one wins.

    For example:
        {
            'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
            'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
            'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
        }

    is summarized as:
        {
            'down_blocks.1.attentions.1.transformer_blocks': [0.6],
            'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
        }

    Args:
        d: Mapping from dot-separated keys to values. Values must be hashable, or lists of hashable items.

    Returns:
        Mapping from common prefix to the shared value. List values are returned as lists.
    """
    # Group keys by value; lists are turned into tuples so they can be used as dict keys.
    value_to_keys: Dict[Any, List[str]] = {}
    for key, value in d.items():
        value_tuple = tuple(value) if isinstance(value, list) else value
        value_to_keys.setdefault(value_tuple, []).append(key)

    def find_common_prefix(keys: List[str]) -> str:
        """Return the longest common prefix (in whole dot-separated parts) of `keys`, or "" if there is none."""
        if not keys:
            return ""
        if len(keys) == 1:
            return keys[0]

        key_parts = [k.split(".") for k in keys]
        common_length = 0
        for parts in zip(*key_parts):
            if len(set(parts)) != 1:  # parts differ at this position
                break
            common_length += 1
        return ".".join(key_parts[0][:common_length])  # "" when nothing is shared

    summary = {}
    for value_tuple, keys in value_to_keys.items():
        # Restore the original list type for every group, including groups without a common prefix
        # (those are stored under "").
        value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple
        summary[find_common_prefix(keys)] = value

    return summary
|
276
|
+
|
277
|
+
|
278
|
+
class ComponentsManager:
|
279
|
+
"""
|
280
|
+
A central registry and management system for model components across multiple pipelines.
|
281
|
+
|
282
|
+
[`ComponentsManager`] provides a unified way to register, track, and reuse model components (like UNet, VAE, text
|
283
|
+
encoders, etc.) across different modular pipelines. It includes features for duplicate detection, memory
|
284
|
+
management, and component organization.
|
285
|
+
|
286
|
+
<Tip warning={true}>
|
287
|
+
|
288
|
+
This is an experimental feature and is likely to change in the future.
|
289
|
+
|
290
|
+
</Tip>
|
291
|
+
|
292
|
+
Example:
|
293
|
+
```python
|
294
|
+
from diffusers import ComponentsManager
|
295
|
+
|
296
|
+
# Create a components manager
|
297
|
+
cm = ComponentsManager()
|
298
|
+
|
299
|
+
# Add components
|
300
|
+
cm.add("unet", unet_model, collection="sdxl")
|
301
|
+
cm.add("vae", vae_model, collection="sdxl")
|
302
|
+
|
303
|
+
# Enable auto offloading
|
304
|
+
cm.enable_auto_cpu_offload(device="cuda")
|
305
|
+
|
306
|
+
# Retrieve components
|
307
|
+
unet = cm.get_one(name="unet", collection="sdxl")
|
308
|
+
```
|
309
|
+
"""
|
310
|
+
|
311
|
+
_available_info_fields = [
|
312
|
+
"model_id",
|
313
|
+
"added_time",
|
314
|
+
"collection",
|
315
|
+
"class_name",
|
316
|
+
"size_gb",
|
317
|
+
"adapters",
|
318
|
+
"has_hook",
|
319
|
+
"execution_device",
|
320
|
+
"ip_adapter",
|
321
|
+
]
|
322
|
+
|
323
|
+
def __init__(self):
    # auto CPU offloading starts disabled
    self._auto_offload_enabled = False
    # presumably holds the offloading hooks once they are set up elsewhere in the class; starts as None
    self.model_hooks = None
    # component_id -> component object, in insertion order
    self.components = OrderedDict()
    # YiYi TODO: can remove once confirm we don't need this in mellon
    # component_id -> time.time() recorded when the component was added
    self.added_time = OrderedDict()
    # collection name -> set of component_ids in that collection
    self.collections = OrderedDict()
|
330
|
+
|
331
|
+
def _lookup_ids(
    self,
    name: Optional[str] = None,
    collection: Optional[str] = None,
    load_id: Optional[str] = None,
    components: Optional[OrderedDict] = None,
):
    """
    Look up component ids by exact name, collection and/or load_id. Pattern matching is not supported.

    Each filter that is given (truthy) narrows the result; filters left as `None`/empty match everything. The result
    is the intersection of all active filters.

    Args:
        name: Base name of the component, i.e. the component id without its trailing `_<suffix>`.
        collection: Collection name. A collection that was never registered matches no components.
        load_id: Compared against each component's `_diffusers_load_id` attribute.
        components: Mapping of component_id -> component to search; defaults to `self.components`.

    Returns:
        `set` of matching component ids.
    """
    if components is None:
        components = self.components

    ids = set(components.keys())

    if name:
        ids = {component_id for component_id in ids if self._id_to_name(component_id) == name}

    if collection:
        # `.get` so an unknown collection yields no matches instead of raising KeyError
        ids &= self.collections.get(collection, set())

    if load_id:
        ids = {
            component_id
            for component_id in ids
            if getattr(components[component_id], "_diffusers_load_id", None) == load_id
        }

    return ids
|
370
|
+
|
371
|
+
@staticmethod
def _id_to_name(component_id: str):
    """
    Recover a component's base name from its id by dropping everything from the last underscore onward (ids are
    built as `f"{name}_{id(component)}"`). Ids without an underscore map to "".
    """
    base_name, _separator, _suffix = component_id.rpartition("_")
    return base_name
|
374
|
+
|
375
|
+
def add(self, name: str, component: Any, collection: Optional[str] = None):
    """
    Add a component to the ComponentsManager.

    If the very same object is already registered under the same name, its existing id is reused (with a warning).
    Registering the same object under a different name, or an object whose `_diffusers_load_id` matches that of an
    existing component, only logs a warning; nothing is deleted automatically. When `collection` is given, any other
    component with the same name in that collection is removed from the collection first (and from the manager
    entirely if no other collection references it).

    Args:
        name (str): The name of the component
        component (Any): The component to add
        collection (Optional[str]): The collection to add the component to

    Returns:
        str: The unique component ID, which is generated as "{name}_{id(component)}" where
             id(component) is Python's built-in unique identifier for the object
    """
    component_id = f"{name}_{id(component)}"
    is_new_component = True

    # Check for duplicated components. Identity matches how ids are built (`id(component)`) and does not depend on
    # a component type's custom `__eq__`.
    for comp_id, comp in self.components.items():
        if comp is not component:
            continue
        if self._id_to_name(comp_id) == name:
            logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'")
            component_id = comp_id
            is_new_component = False
            break
        logger.warning(
            f"ComponentsManager: adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'. "
            f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
        )

    # check for duplicated load_id and warn (we do not delete for you)
    load_id = getattr(component, "_diffusers_load_id", None)
    # `_lookup_ids` treats a falsy load_id as "no filter", so only real load ids are checked
    if load_id and load_id != "null":
        components_with_same_load_id = sorted(
            other_id for other_id in self._lookup_ids(load_id=load_id) if other_id != component_id
        )
        if components_with_same_load_id:
            existing = ", ".join(components_with_same_load_id)
            logger.warning(
                f"ComponentsManager: adding component '{component_id}', but it has duplicate load_id '{load_id}' with existing components: {existing}. "
                f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
            )

    # add component to components manager
    self.components[component_id] = component
    self.added_time[component_id] = time.time()

    if collection:
        if collection not in self.collections:
            self.collections[collection] = set()
        if component_id not in self.collections[collection]:
            comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
            for comp_id in comp_ids_in_collection:
                logger.warning(
                    f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}"
                )
                # remove existing component from this collection (if it is not in any other collection, will be removed from ComponentsManager)
                self.remove_from_collection(comp_id, collection)

            self.collections[collection].add(component_id)
            logger.info(
                f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}"
            )
    else:
        logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'")

    if self._auto_offload_enabled and is_new_component:
        # re-apply auto offloading so the newly added component is covered
        self.enable_auto_cpu_offload(self._auto_offload_device)

    return component_id
|
445
|
+
|
446
|
+
def remove_from_collection(self, component_id: str, collection: str):
    """
    Remove `component_id` from `collection`. If no collection references the component afterwards, it is removed from
    the ComponentsManager as well (through `remove`). Unknown collections or ids only produce a warning.
    """
    if collection not in self.collections:
        logger.warning(f"Collection '{collection}' not found in ComponentsManager")
        return

    members = self.collections[collection]
    if component_id not in members:
        logger.warning(f"Component '{component_id}' not found in collection '{collection}'")
        return

    members.remove(component_id)

    # only drop the component entirely when no other collection still contains it
    still_referenced = any(component_id in ids for ids in self.collections.values())
    if not still_referenced:
        logger.warning(f"ComponentsManager: removing component '{component_id}' from ComponentsManager")
        self.remove(component_id)
|
463
|
+
|
464
|
+
def remove(self, component_id: str = None):
    """
    Remove a component from the ComponentsManager and from every collection that references it.

    With auto CPU offloading enabled, offloading is re-applied for the remaining components. Otherwise a removed
    `torch.nn.Module` is moved to the CPU, garbage collection runs, and the CUDA cache is emptied when CUDA is
    available. Unknown ids only produce a warning.

    Args:
        component_id (str): The ID of the component to remove
    """
    import gc

    try:
        component = self.components.pop(component_id)
    except KeyError:
        logger.warning(f"Component '{component_id}' not found in ComponentsManager")
        return
    self.added_time.pop(component_id)

    for members in self.collections.values():
        members.discard(component_id)

    if self._auto_offload_enabled:
        self.enable_auto_cpu_offload(self._auto_offload_device)
        return

    if isinstance(component, torch.nn.Module):
        component.to("cpu")
    # drops only this local reference; callers may still hold the object
    del component
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
|
493
|
+
|
494
|
+
# YiYi TODO: rename to search_components for now, may remove this method
def search_components(
    self,
    names: Optional[str] = None,
    collection: Optional[str] = None,
    load_id: Optional[str] = None,
    return_dict_with_names: bool = True,
):
    """
    Search components by name with simple pattern matching. Optionally filter by collection or load_id.

    Args:
        names: Component name(s) or pattern(s)
            Patterns:
            - "unet" : match any component with base name "unet" (e.g., unet_123abc)
            - "!unet" : everything except components with base name "unet"
            - "unet*" : anything with base name starting with "unet"
            - "!unet*" : anything with base name NOT starting with "unet"
            - "*unet*" : anything with base name containing "unet"
            - "!*unet*" : anything with base name NOT containing "unet"
            - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet"
            - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet"
            - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
            A leading "*" without a trailing one ("*unet") is also treated as a "contains" match.
        collection: Optional collection to filter by
        load_id: Optional load_id to filter by
        return_dict_with_names:
            If True, returns a dictionary with component names as keys, throw an error if multiple components with
            the same name are found. If False, returns a dictionary with component IDs as keys.

    Returns:
        Dictionary mapping component names to components if return_dict_with_names=True, or a dictionary mapping
        component IDs to components if return_dict_with_names=False

    Raises:
        ValueError: if `names` is not a string, if nothing matches the pattern, or if return_dict_with_names=True
            and two selected components share the same base name.
    """

    # select components based on collection and load_id filters
    selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
    components = {k: self.components[k] for k in selected_ids}

    def get_return_dict(components, return_dict_with_names):
        """
        Key the selected components by base name (return_dict_with_names=True, duplicates raise) or by ID.
        """
        if not return_dict_with_names:
            return components
        dict_to_return = {}
        for comp_id, comp in components.items():
            comp_name = self._id_to_name(comp_id)
            if comp_name in dict_to_return:
                raise ValueError(
                    f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
                )
            dict_to_return[comp_name] = comp
        return dict_to_return

    # if no names are provided, return the filtered components as it is
    if names is None:
        return get_return_dict(components, return_dict_with_names)

    # if names is not a string, raise an error
    if not isinstance(names, str):
        raise ValueError(f"Invalid type for `names`: {type(names)}, only support string")

    # Create mapping from component_id to base_name for components to be used for pattern matching
    base_names = {comp_id: self._id_to_name(comp_id) for comp_id in components}

    def matches_pattern(component_id, pattern, exact_match=False):
        """
        Check whether a component's base name matches a single (non-negated, non-OR) pattern.

        Args:
            component_id: The component ID to check
            pattern: The pattern to match against
            exact_match: If True, only exact matches to base_name are considered
        """
        base_name = base_names[component_id]

        if exact_match:
            return pattern == base_name

        # Contains match ("*unet*" or "*unet"). This must be tested before the prefix case, otherwise
        # "*unet*" would be handled as a prefix match on the literal string "*unet" and never match.
        if pattern.startswith("*"):
            search = pattern[1:-1] if pattern.endswith("*") else pattern[1:]
            return search in base_name

        # Prefix match (ends with *)
        if pattern.endswith("*"):
            return base_name.startswith(pattern[:-1])

        # Exact match (no wildcards)
        return pattern == base_name

    # Check if this is a "not" pattern
    is_not_pattern = names.startswith("!")
    if is_not_pattern:
        names = names[1:]  # Remove the ! prefix

    # Handle OR patterns (containing |)
    if "|" in names:
        terms = names.split("|")
        # Depends only on the terms, so compute it once; it must also be defined for the log line below
        # even when there are no candidate components.
        exact_match = all(not (term.startswith("*") or term.endswith("*")) for term in terms)

        # `!= is_not_pattern` flips the decision for NOT patterns
        matches = {
            comp_id: comp
            for comp_id, comp in components.items()
            if any(matches_pattern(comp_id, term, exact_match) for term in terms) != is_not_pattern
        }

        log_msg = "NOT " if is_not_pattern else ""
        match_type = "exactly matching" if exact_match else "matching any of patterns"
        logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")

    # Try exact match with a base name
    elif any(names == base_name for base_name in base_names.values()):
        matches = {
            comp_id: comp
            for comp_id, comp in components.items()
            if (base_names[comp_id] == names) != is_not_pattern
        }
        if is_not_pattern:
            logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
        else:
            logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")

    # Contains match (starts with *) -- checked before the prefix match so that "*unet*" works
    elif names.startswith("*"):
        search = names[1:-1] if names.endswith("*") else names[1:]
        matches = {
            comp_id: comp
            for comp_id, comp in components.items()
            if (search in base_names[comp_id]) != is_not_pattern
        }
        if is_not_pattern:
            logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
        else:
            logger.info(f"Getting components containing '{search}': {list(matches.keys())}")

    # Prefix match (ends with *)
    elif names.endswith("*"):
        prefix = names[:-1]
        matches = {
            comp_id: comp
            for comp_id, comp in components.items()
            if base_names[comp_id].startswith(prefix) != is_not_pattern
        }
        if is_not_pattern:
            logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
        else:
            logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")

    # Substring match (no wildcards, but not an exact component name)
    elif any(names in base_name for base_name in base_names.values()):
        matches = {
            comp_id: comp
            for comp_id, comp in components.items()
            if (names in base_names[comp_id]) != is_not_pattern
        }
        if is_not_pattern:
            logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
        else:
            logger.info(f"Getting components containing '{names}': {list(matches.keys())}")

    else:
        raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")

    if not matches:
        raise ValueError(f"No components found matching pattern '{names}'")

    return get_return_dict(matches, return_dict_with_names)
|
680
|
+
|
681
|
+
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
    """
    Turn on automatic CPU offloading for every `torch.nn.Module` registered in the manager.

    The algorithm works as follows:
    1. All models start on CPU by default
    2. When a model's forward pass is called, it's moved to the execution device
    3. If there's insufficient memory, other models on the device are moved back to CPU
    4. The system tries to offload the smallest combination of models that frees enough memory
    5. Models stay on the execution device until another model needs memory and forces them off

    Args:
        device (Union[str, int, torch.device]): The execution device where models are moved for forward passes.
            A device without an index (e.g. "cuda") is resolved to index 0.
        memory_reserve_margin (str): The memory reserve margin to use, default is 3GB. This is the amount of
            memory to keep free on the device to avoid running out of memory during model execution (e.g., for
            intermediate activations, gradients, etc.)

    Raises:
        ImportError: if `accelerate` is not installed.
    """
    if not is_accelerate_available():
        raise ImportError("Make sure to install accelerate to use auto_cpu_offload")

    # detach any accelerate hook already attached to a registered model
    for component in self.components.values():
        if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
            remove_hook_from_module(component, recurse=True)

    # tear down hooks from a previous enable_auto_cpu_offload call (if any)
    self.disable_auto_cpu_offload()
    offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)

    device = torch.device(device)
    if device.index is None:
        device = torch.device(f"{device.type}:0")

    model_hooks = [
        custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy)
        for name, component in self.components.items()
        if isinstance(component, torch.nn.Module)
    ]

    # let every hook know about the other hooks sharing its execution device
    for hook in model_hooks:
        for peer in model_hooks:
            if peer is not hook and peer.hook.execution_device == hook.hook.execution_device:
                hook.add_other_hook(peer)

    self.model_hooks = model_hooks
    self._auto_offload_enabled = True
    self._auto_offload_device = device
|
725
|
+
|
726
|
+
def disable_auto_cpu_offload(self):
    """
    Turn off automatic CPU offloading.

    Each installed hook is asked to offload its model and is then removed; if any hooks existed the device cache
    is cleared. Afterwards `self.model_hooks` is None and `self._auto_offload_enabled` is False.
    """
    hooks = self.model_hooks
    if hooks is not None:
        for hook in hooks:
            hook.offload()
            hook.remove()
        if hooks:
            clear_device_cache()
        self.model_hooks = None
    self._auto_offload_enabled = False
|
741
|
+
|
742
|
+
# YiYi TODO: (1) add quantization info
def get_model_info(
    self,
    component_id: str,
    fields: Optional[Union[str, List[str]]] = None,
) -> Optional[Dict[str, Any]]:
    """Get comprehensive information about a component.

    Args:
        component_id (str): Name of the component to get info for
        fields (Optional[Union[str, List[str]]]):
            Field(s) to return. Can be a string for single field or list of fields. Every requested field must be
            listed in `self._available_info_fields`. If None, all collected fields are returned.

    Returns:
        Dictionary containing component metadata ("model_id", "added_time", "collection"; plus "class_name",
        "size_gb", "adapters", "has_hook", "execution_device" and optionally "ip_adapter" for `torch.nn.Module`
        components). If fields is specified, returns only those fields.

    Raises:
        ValueError: if `component_id` is not registered or a requested field is not in available_info_fields.
    """
    if component_id not in self.components:
        raise ValueError(f"Component '{component_id}' not found in ComponentsManager")

    component = self.components[component_id]

    # Validate fields if specified
    if fields is not None:
        if isinstance(fields, str):
            fields = [fields]
        for field in fields:
            if field not in self._available_info_fields:
                raise ValueError(f"Field '{field}' not found in available_info_fields")

    # Build complete info dict first
    info = {
        "model_id": component_id,
        "added_time": self.added_time[component_id],
        "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps])
        or None,
    }

    # Additional info for torch.nn.Module components
    if isinstance(component, torch.nn.Module):
        # Check for hook information
        has_hook = hasattr(component, "_hf_hook")
        execution_device = None
        if has_hook and hasattr(component._hf_hook, "execution_device"):
            execution_device = component._hf_hook.execution_device

        if hasattr(component, "get_memory_footprint"):
            size_bytes = component.get_memory_footprint()
        else:
            # A plain torch.nn.Module has no get_memory_footprint(); count parameter and buffer storage instead
            # rather than raising AttributeError (which would also break __repr__).
            size_bytes = sum(p.nelement() * p.element_size() for p in component.parameters())
            size_bytes += sum(b.nelement() * b.element_size() for b in component.buffers())

        info.update(
            {
                "class_name": component.__class__.__name__,
                "size_gb": size_bytes / (1024**3),
                "adapters": None,  # Default to None
                "has_hook": has_hook,
                "execution_device": execution_device,
            }
        )

        # Get adapters if applicable
        if hasattr(component, "peft_config"):
            info["adapters"] = list(component.peft_config.keys())

        # Check for IP-Adapter scales
        if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"):
            # The processors are only read here, so no deep copy is made (a deepcopy would also duplicate any
            # tensors/submodules the processors hold).
            processors = component.attn_processors
            # Only IP-Adapter processors that expose a `scale` contribute
            scales = {
                k: v.scale
                for k, v in processors.items()
                if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
            }
            if scales:
                info["ip_adapter"] = summarize_dict_by_value_and_parts(scales)

    # If fields specified, filter info
    if fields is not None:
        return {k: v for k, v in info.items() if k in fields}
    return info
|
823
|
+
|
824
|
+
# YiYi TODO: (1) add display fields, allow user to set which fields to display in the components table
def __repr__(self):
    """
    Render the registered components as a text table.

    Models (`torch.nn.Module` instances) are listed with class, device, dtype, size, load ID and collections;
    other components with class and collections. A final section lists adapters and IP-Adapter status for
    components whose info reports them.
    """
    # Handle empty components case
    if not self.components:
        return "Components:\n" + "=" * 50 + "\nNo components registered.\n" + "=" * 50

    # Collect metadata once per component and reuse it in every section below (get_model_info used to be
    # called up to three times per component while building the table).
    infos = {name: self.get_model_info(name) for name in self.components}

    def get_load_id(component):
        # str() so a non-string load id cannot break the fixed-width formatting below
        return str(getattr(component, "_diffusers_load_id", "N/A"))

    def format_device(component, info):
        # "<device>" or, for hooked models, "<device>(<execution device>)"
        device = str(getattr(component, "device", "N/A"))
        if not info["has_hook"]:
            return device
        exec_device = str(info["execution_device"] or "N/A")
        return f"{device}({exec_device})"

    # Width of the Load ID column (at least 15)
    load_ids = [
        get_load_id(component)
        for component in self.components.values()
        if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
    ]
    max_load_id_len = max([15] + [len(lid) for lid in load_ids])

    # All collections of each component; "N/A" when it belongs to none (so every list is non-empty)
    component_collections = {
        name: [coll for coll, comps in self.collections.items() if name in comps] or ["N/A"]
        for name in self.components
    }

    # Width of the Collection column (at least 10)
    all_collections = [coll for colls in component_collections.values() for coll in colls]
    max_collection_len = max([10] + [len(str(c)) for c in all_collections])

    col_widths = {
        "id": max(15, max(len(name) for name in self.components.keys())),
        "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
        "device": 20,
        "dtype": 15,
        "size": 10,
        "load_id": max_load_id_len,
        "collection": max_collection_len,
    }

    # Create the header lines
    table_width = sum(col_widths.values()) + len(col_widths) * 3 - 1
    sep_line = "=" * table_width + "\n"
    dash_line = "-" * table_width + "\n"

    output = "Components:\n" + sep_line

    # Separate components into models and others
    models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
    others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)}

    # Models section
    if models:
        output += "Models:\n" + dash_line
        # Column headers
        output += f"{'Name_ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
        output += f"{'Device: act(exec)':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | "
        output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n"
        output += dash_line

        # Model entries
        for name, component in models.items():
            info = infos[name]
            device_str = format_device(component, info)
            dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
            load_id = get_load_id(component)
            collections = component_collections[name]

            # First collection goes on the main line
            output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | "
            output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
            output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collections[0]}\n"

            # Additional collections go on continuation lines under the Collection column
            for collection in collections[1:]:
                output += f"{'':<{col_widths['id']}} | {'':<{col_widths['class']}} | "
                output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | "
                output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n"

        output += dash_line

    # Other components section
    if others:
        if models:  # Add extra newline if we had models section
            output += "\n"
        output += "Other Components:\n" + dash_line
        # Column headers for other components
        output += f"{'ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | Collection\n"
        output += dash_line

        # Other component entries
        for name, component in others.items():
            collections = component_collections[name]
            output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}} | {collections[0]}\n"
            for collection in collections[1:]:
                output += f"{'':<{col_widths['id']}} | {'':<{col_widths['class']}} | {collection}\n"

        output += dash_line

    # Add additional component info
    output += "\nAdditional Component Info:\n" + "=" * 50 + "\n"
    for name in self.components:
        info = infos[name]
        if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
            output += f"\n{name}:\n"
            if info.get("adapters") is not None:
                output += f"  Adapters: {info['adapters']}\n"
            if info.get("ip_adapter"):
                output += "  IP-Adapter: Enabled\n"

    return output
|
956
|
+
|
957
|
+
def get_one(
    self,
    component_id: Optional[str] = None,
    name: Optional[str] = None,
    collection: Optional[str] = None,
    load_id: Optional[str] = None,
) -> Any:
    """
    Return exactly one component, looked up either directly by `component_id` or via `search_components` using
    `name` (pattern), `collection` and/or `load_id`.

    Args:
        component_id (Optional[str]): Exact component ID; cannot be combined with the search arguments
        name (Optional[str]): Component name or pattern
        collection (Optional[str]): Optional collection to filter by
        load_id (Optional[str]): Optional load_id to filter by

    Returns:
        The single matching component.

    Raises:
        ValueError: if `component_id` is combined with search arguments, if the ID is unknown, or if the search
            yields zero or more than one component.
    """
    has_search_args = name is not None or collection is not None or load_id is not None

    if component_id is not None:
        if has_search_args:
            raise ValueError("If searching by component_id, do not pass name, collection, or load_id")
        if component_id not in self.components:
            raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
        return self.components[component_id]

    found = self.search_components(name, collection, load_id)
    if not found:
        raise ValueError(f"No components found matching '{name}'")
    if len(found) > 1:
        raise ValueError(f"Multiple components found matching '{name}': {list(found.keys())}")

    (only_component,) = found.values()
    return only_component
|
1001
|
+
|
1002
|
+
def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None):
    """
    Collect the IDs of all components matching any of the given names, optionally within one collection.

    Args:
        names (Union[str, List[str]]): A single name or a list of names; anything that is not a list is treated
            as a single name (including None)
        collection (Optional[str]): Optional collection to filter by

    Returns:
        List[str]: De-duplicated list of component IDs (order not guaranteed)
    """
    name_list = names if isinstance(names, list) else [names]
    found = set()
    for single_name in name_list:
        found.update(self._lookup_ids(name=single_name, collection=collection))
    return list(found)
|
1019
|
+
|
1020
|
+
def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True):
    """
    Look up components by ID.

    Args:
        ids (List[str]):
            List of component IDs (each must be registered)
        return_dict_with_names (Optional[bool]):
            Whether to key the result by component name instead of component ID.

    Returns:
        Dict[str, Any]: Dictionary of components.
        - If return_dict_with_names=True, keys are component names.
        - If return_dict_with_names=False, keys are component IDs.

    Raises:
        ValueError: If duplicate component names are found in the search results when return_dict_with_names=True
    """
    by_id = {component_id: self.components[component_id] for component_id in ids}
    if not return_dict_with_names:
        return by_id

    by_name = {}
    for component_id, component in by_id.items():
        component_name = self._id_to_name(component_id)
        if component_name in by_name:
            raise ValueError(
                f"Duplicate component names found in the search results: {component_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
            )
        by_name[component_name] = component
    return by_name
|
1052
|
+
|
1053
|
+
def get_components_by_names(self, names: List[str], collection: Optional[str] = None):
    """
    Resolve names (optionally restricted to one collection) to components keyed by component name.

    This is `get_ids(names, collection)` followed by `get_components_by_ids` with its default name-keyed output.

    Args:
        names (List[str]): List of component names
        collection (Optional[str]): Optional collection to filter by

    Returns:
        Dict[str, Any]: Dictionary of components with component names as keys

    Raises:
        ValueError: If duplicate component names are found in the search results
    """
    return self.get_components_by_ids(self.get_ids(names, collection))
|