diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
diffusers/loaders/peft.py
CHANGED
@@ -61,6 +61,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
|
61
61
|
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
|
62
62
|
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
|
63
63
|
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
|
64
|
+
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
|
64
65
|
}
|
65
66
|
|
66
67
|
|
@@ -163,6 +164,8 @@ class PeftAdapterMixin:
|
|
163
164
|
from peft import inject_adapter_in_model, set_peft_model_state_dict
|
164
165
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
165
166
|
|
167
|
+
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
|
168
|
+
|
166
169
|
cache_dir = kwargs.pop("cache_dir", None)
|
167
170
|
force_download = kwargs.pop("force_download", False)
|
168
171
|
proxies = kwargs.pop("proxies", None)
|
@@ -243,20 +246,29 @@ class PeftAdapterMixin:
|
|
243
246
|
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
|
244
247
|
}
|
245
248
|
|
246
|
-
# create LoraConfig
|
247
|
-
lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
|
248
|
-
|
249
249
|
# adapter_name
|
250
250
|
if adapter_name is None:
|
251
251
|
adapter_name = get_adapter_name(self)
|
252
252
|
|
253
|
+
# create LoraConfig
|
254
|
+
lora_config = _create_lora_config(
|
255
|
+
state_dict,
|
256
|
+
network_alphas,
|
257
|
+
metadata,
|
258
|
+
rank,
|
259
|
+
model_state_dict=self.state_dict(),
|
260
|
+
adapter_name=adapter_name,
|
261
|
+
)
|
262
|
+
|
253
263
|
# <Unsafe code
|
254
264
|
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
|
255
265
|
# Now we remove any existing hooks to `_pipeline`.
|
256
266
|
|
257
267
|
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
258
268
|
# otherwise loading LoRA weights will lead to an error.
|
259
|
-
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(
|
269
|
+
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
|
270
|
+
_pipeline
|
271
|
+
)
|
260
272
|
peft_kwargs = {}
|
261
273
|
if is_peft_version(">=", "0.13.1"):
|
262
274
|
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
@@ -308,7 +320,9 @@ class PeftAdapterMixin:
|
|
308
320
|
# it to None
|
309
321
|
incompatible_keys = None
|
310
322
|
else:
|
311
|
-
inject_adapter_in_model(
|
323
|
+
inject_adapter_in_model(
|
324
|
+
lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
|
325
|
+
)
|
312
326
|
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
|
313
327
|
|
314
328
|
if self._prepare_lora_hotswap_kwargs is not None:
|
@@ -347,6 +361,10 @@ class PeftAdapterMixin:
|
|
347
361
|
_pipeline.enable_model_cpu_offload()
|
348
362
|
elif is_sequential_cpu_offload:
|
349
363
|
_pipeline.enable_sequential_cpu_offload()
|
364
|
+
elif is_group_offload:
|
365
|
+
for component in _pipeline.components.values():
|
366
|
+
if isinstance(component, torch.nn.Module):
|
367
|
+
_maybe_remove_and_reapply_group_offloading(component)
|
350
368
|
# Unsafe code />
|
351
369
|
|
352
370
|
if prefix is not None and not state_dict:
|
@@ -681,11 +699,16 @@ class PeftAdapterMixin:
|
|
681
699
|
if not USE_PEFT_BACKEND:
|
682
700
|
raise ValueError("PEFT backend is required for `unload_lora()`.")
|
683
701
|
|
702
|
+
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
|
684
703
|
from ..utils import recurse_remove_peft_layers
|
685
704
|
|
686
705
|
recurse_remove_peft_layers(self)
|
687
706
|
if hasattr(self, "peft_config"):
|
688
707
|
del self.peft_config
|
708
|
+
if hasattr(self, "_hf_peft_config_loaded"):
|
709
|
+
self._hf_peft_config_loaded = None
|
710
|
+
|
711
|
+
_maybe_remove_and_reapply_group_offloading(self)
|
689
712
|
|
690
713
|
def disable_lora(self):
|
691
714
|
"""
|
@@ -23,7 +23,8 @@ from typing_extensions import Self
|
|
23
23
|
|
24
24
|
from .. import __version__
|
25
25
|
from ..quantizers import DiffusersAutoQuantizer
|
26
|
-
from ..utils import deprecate, is_accelerate_available, logging
|
26
|
+
from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
|
27
|
+
from ..utils.torch_utils import empty_device_cache
|
27
28
|
from .single_file_utils import (
|
28
29
|
SingleFileComponentError,
|
29
30
|
convert_animatediff_checkpoint_to_diffusers,
|
@@ -31,6 +32,7 @@ from .single_file_utils import (
|
|
31
32
|
convert_autoencoder_dc_checkpoint_to_diffusers,
|
32
33
|
convert_chroma_transformer_checkpoint_to_diffusers,
|
33
34
|
convert_controlnet_checkpoint,
|
35
|
+
convert_cosmos_transformer_checkpoint_to_diffusers,
|
34
36
|
convert_flux_transformer_checkpoint_to_diffusers,
|
35
37
|
convert_hidream_transformer_to_diffusers,
|
36
38
|
convert_hunyuan_video_transformer_to_diffusers,
|
@@ -60,8 +62,12 @@ logger = logging.get_logger(__name__)
|
|
60
62
|
if is_accelerate_available():
|
61
63
|
from accelerate import dispatch_model, init_empty_weights
|
62
64
|
|
63
|
-
from ..models.
|
65
|
+
from ..models.model_loading_utils import load_model_dict_into_meta
|
64
66
|
|
67
|
+
if is_torch_version(">=", "1.9.0") and is_accelerate_available():
|
68
|
+
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
69
|
+
else:
|
70
|
+
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
65
71
|
|
66
72
|
SINGLE_FILE_LOADABLE_CLASSES = {
|
67
73
|
"StableCascadeUNet": {
|
@@ -135,6 +141,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
|
135
141
|
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
136
142
|
"default_subfolder": "transformer",
|
137
143
|
},
|
144
|
+
"WanVACETransformer3DModel": {
|
145
|
+
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
146
|
+
"default_subfolder": "transformer",
|
147
|
+
},
|
138
148
|
"AutoencoderKLWan": {
|
139
149
|
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
|
140
150
|
"default_subfolder": "vae",
|
@@ -143,9 +153,21 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
|
143
153
|
"checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
|
144
154
|
"default_subfolder": "transformer",
|
145
155
|
},
|
156
|
+
"CosmosTransformer3DModel": {
|
157
|
+
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
|
158
|
+
"default_subfolder": "transformer",
|
159
|
+
},
|
160
|
+
"QwenImageTransformer2DModel": {
|
161
|
+
"checkpoint_mapping_fn": lambda x: x,
|
162
|
+
"default_subfolder": "transformer",
|
163
|
+
},
|
146
164
|
}
|
147
165
|
|
148
166
|
|
167
|
+
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
|
168
|
+
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
|
169
|
+
|
170
|
+
|
149
171
|
def _get_single_file_loadable_mapping_class(cls):
|
150
172
|
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
151
173
|
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
|
@@ -218,6 +240,11 @@ class FromOriginalModelMixin:
|
|
218
240
|
revision (`str`, *optional*, defaults to `"main"`):
|
219
241
|
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
220
242
|
allowed by Git.
|
243
|
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 and
|
244
|
+
is_accelerate_available() else `False`): Speed up model loading only loading the pretrained weights and
|
245
|
+
not initializing the weights. This also tries to not use more than 1x model size in CPU memory
|
246
|
+
(including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using
|
247
|
+
an older version of PyTorch, setting this argument to `True` will raise an error.
|
221
248
|
disable_mmap ('bool', *optional*, defaults to 'False'):
|
222
249
|
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
223
250
|
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
@@ -267,6 +294,7 @@ class FromOriginalModelMixin:
|
|
267
294
|
config_revision = kwargs.pop("config_revision", None)
|
268
295
|
torch_dtype = kwargs.pop("torch_dtype", None)
|
269
296
|
quantization_config = kwargs.pop("quantization_config", None)
|
297
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
270
298
|
device = kwargs.pop("device", None)
|
271
299
|
disable_mmap = kwargs.pop("disable_mmap", False)
|
272
300
|
|
@@ -371,19 +399,23 @@ class FromOriginalModelMixin:
|
|
371
399
|
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
|
372
400
|
diffusers_model_config.update(model_kwargs)
|
373
401
|
|
402
|
+
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
|
403
|
+
with ctx():
|
404
|
+
model = cls.from_config(diffusers_model_config)
|
405
|
+
|
374
406
|
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
375
|
-
|
376
|
-
|
377
|
-
|
407
|
+
|
408
|
+
if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
|
409
|
+
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
410
|
+
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
411
|
+
)
|
412
|
+
else:
|
413
|
+
diffusers_format_checkpoint = checkpoint
|
414
|
+
|
378
415
|
if not diffusers_format_checkpoint:
|
379
416
|
raise SingleFileComponentError(
|
380
417
|
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
381
418
|
)
|
382
|
-
|
383
|
-
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
384
|
-
with ctx():
|
385
|
-
model = cls.from_config(diffusers_model_config)
|
386
|
-
|
387
419
|
# Check if `_keep_in_fp32_modules` is not None
|
388
420
|
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
389
421
|
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
@@ -405,7 +437,7 @@ class FromOriginalModelMixin:
|
|
405
437
|
)
|
406
438
|
|
407
439
|
device_map = None
|
408
|
-
if
|
440
|
+
if low_cpu_mem_usage:
|
409
441
|
param_device = torch.device(device) if device else torch.device("cpu")
|
410
442
|
empty_state_dict = model.state_dict()
|
411
443
|
unexpected_keys = [
|
@@ -421,6 +453,7 @@ class FromOriginalModelMixin:
|
|
421
453
|
keep_in_fp32_modules=keep_in_fp32_modules,
|
422
454
|
unexpected_keys=unexpected_keys,
|
423
455
|
)
|
456
|
+
empty_device_cache()
|
424
457
|
else:
|
425
458
|
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
426
459
|
|
@@ -46,6 +46,7 @@ from ..utils import (
|
|
46
46
|
)
|
47
47
|
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
48
48
|
from ..utils.hub_utils import _get_model_file
|
49
|
+
from ..utils.torch_utils import empty_device_cache
|
49
50
|
|
50
51
|
|
51
52
|
if is_transformers_available():
|
@@ -54,11 +55,12 @@ if is_transformers_available():
|
|
54
55
|
if is_accelerate_available():
|
55
56
|
from accelerate import init_empty_weights
|
56
57
|
|
57
|
-
from ..models.
|
58
|
+
from ..models.model_loading_utils import load_model_dict_into_meta
|
58
59
|
|
59
60
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
60
61
|
|
61
62
|
CHECKPOINT_KEY_NAMES = {
|
63
|
+
"v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
|
62
64
|
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
63
65
|
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
|
64
66
|
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
|
@@ -126,7 +128,18 @@ CHECKPOINT_KEY_NAMES = {
|
|
126
128
|
],
|
127
129
|
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
|
128
130
|
"wan_vae": "decoder.middle.0.residual.0.gamma",
|
131
|
+
"wan_vace": "vace_blocks.0.after_proj.bias",
|
129
132
|
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
|
133
|
+
"cosmos-1.0": [
|
134
|
+
"net.x_embedder.proj.1.weight",
|
135
|
+
"net.blocks.block1.blocks.0.block.attn.to_q.0.weight",
|
136
|
+
"net.extra_pos_embedder.pos_emb_h",
|
137
|
+
],
|
138
|
+
"cosmos-2.0": [
|
139
|
+
"net.x_embedder.proj.1.weight",
|
140
|
+
"net.blocks.0.self_attn.q_proj.weight",
|
141
|
+
"net.pos_embedder.dim_spatial_range",
|
142
|
+
],
|
130
143
|
}
|
131
144
|
|
132
145
|
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
@@ -192,7 +205,17 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
|
192
205
|
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
|
193
206
|
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
|
194
207
|
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
|
208
|
+
"wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
|
209
|
+
"wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
|
195
210
|
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
|
211
|
+
"cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
|
212
|
+
"cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
|
213
|
+
"cosmos-1.0-v2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"},
|
214
|
+
"cosmos-1.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Video2World"},
|
215
|
+
"cosmos-2.0-t2i-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Text2Image"},
|
216
|
+
"cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
|
217
|
+
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
|
218
|
+
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
|
196
219
|
}
|
197
220
|
|
198
221
|
# Use to configure model sample size when original config is provided
|
@@ -698,17 +721,44 @@ def infer_diffusers_model_type(checkpoint):
|
|
698
721
|
else:
|
699
722
|
target_key = "patch_embedding.weight"
|
700
723
|
|
701
|
-
if
|
724
|
+
if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint:
|
725
|
+
if checkpoint[target_key].shape[0] == 1536:
|
726
|
+
model_type = "wan-vace-1.3B"
|
727
|
+
elif checkpoint[target_key].shape[0] == 5120:
|
728
|
+
model_type = "wan-vace-14B"
|
729
|
+
|
730
|
+
elif checkpoint[target_key].shape[0] == 1536:
|
702
731
|
model_type = "wan-t2v-1.3B"
|
703
732
|
elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
|
704
733
|
model_type = "wan-t2v-14B"
|
705
734
|
else:
|
706
735
|
model_type = "wan-i2v-14B"
|
736
|
+
|
707
737
|
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
|
708
738
|
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
|
709
739
|
model_type = "wan-t2v-14B"
|
740
|
+
|
710
741
|
elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
|
711
742
|
model_type = "hidream"
|
743
|
+
|
744
|
+
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-1.0"]):
|
745
|
+
x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-1.0"][0]].shape
|
746
|
+
if x_embedder_shape[1] == 68:
|
747
|
+
model_type = "cosmos-1.0-t2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-t2w-14B"
|
748
|
+
elif x_embedder_shape[1] == 72:
|
749
|
+
model_type = "cosmos-1.0-v2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-v2w-14B"
|
750
|
+
else:
|
751
|
+
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 1.0 model.")
|
752
|
+
|
753
|
+
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-2.0"]):
|
754
|
+
x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-2.0"][0]].shape
|
755
|
+
if x_embedder_shape[1] == 68:
|
756
|
+
model_type = "cosmos-2.0-t2i-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-t2i-14B"
|
757
|
+
elif x_embedder_shape[1] == 72:
|
758
|
+
model_type = "cosmos-2.0-v2w-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-v2w-14B"
|
759
|
+
else:
|
760
|
+
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
|
761
|
+
|
712
762
|
else:
|
713
763
|
model_type = "v1"
|
714
764
|
|
@@ -1641,6 +1691,7 @@ def create_diffusers_clip_model_from_ldm(
|
|
1641
1691
|
|
1642
1692
|
if is_accelerate_available():
|
1643
1693
|
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
1694
|
+
empty_device_cache()
|
1644
1695
|
else:
|
1645
1696
|
model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
1646
1697
|
|
@@ -2100,6 +2151,7 @@ def create_diffusers_t5_model_from_checkpoint(
|
|
2100
2151
|
|
2101
2152
|
if is_accelerate_available():
|
2102
2153
|
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
2154
|
+
empty_device_cache()
|
2103
2155
|
else:
|
2104
2156
|
model.load_state_dict(diffusers_format_checkpoint)
|
2105
2157
|
|
@@ -3093,6 +3145,9 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
|
|
3093
3145
|
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
|
3094
3146
|
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
|
3095
3147
|
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
|
3148
|
+
# For the VACE model
|
3149
|
+
"before_proj": "proj_in",
|
3150
|
+
"after_proj": "proj_out",
|
3096
3151
|
}
|
3097
3152
|
|
3098
3153
|
for key in list(checkpoint.keys()):
|
@@ -3479,3 +3534,116 @@ def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
|
3479
3534
|
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
3480
3535
|
|
3481
3536
|
return converted_state_dict
|
3537
|
+
|
3538
|
+
|
3539
|
+
def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
3540
|
+
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
|
3541
|
+
|
3542
|
+
def remove_keys_(key: str, state_dict):
|
3543
|
+
state_dict.pop(key)
|
3544
|
+
|
3545
|
+
def rename_transformer_blocks_(key: str, state_dict):
|
3546
|
+
block_index = int(key.split(".")[1].removeprefix("block"))
|
3547
|
+
new_key = key
|
3548
|
+
old_prefix = f"blocks.block{block_index}"
|
3549
|
+
new_prefix = f"transformer_blocks.{block_index}"
|
3550
|
+
new_key = new_prefix + new_key.removeprefix(old_prefix)
|
3551
|
+
state_dict[new_key] = state_dict.pop(key)
|
3552
|
+
|
3553
|
+
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
|
3554
|
+
"t_embedder.1": "time_embed.t_embedder",
|
3555
|
+
"affline_norm": "time_embed.norm",
|
3556
|
+
".blocks.0.block.attn": ".attn1",
|
3557
|
+
".blocks.1.block.attn": ".attn2",
|
3558
|
+
".blocks.2.block": ".ff",
|
3559
|
+
".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
|
3560
|
+
".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
|
3561
|
+
".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
|
3562
|
+
".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
|
3563
|
+
".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
|
3564
|
+
".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
|
3565
|
+
"to_q.0": "to_q",
|
3566
|
+
"to_q.1": "norm_q",
|
3567
|
+
"to_k.0": "to_k",
|
3568
|
+
"to_k.1": "norm_k",
|
3569
|
+
"to_v.0": "to_v",
|
3570
|
+
"layer1": "net.0.proj",
|
3571
|
+
"layer2": "net.2",
|
3572
|
+
"proj.1": "proj",
|
3573
|
+
"x_embedder": "patch_embed",
|
3574
|
+
"extra_pos_embedder": "learnable_pos_embed",
|
3575
|
+
"final_layer.adaLN_modulation.1": "norm_out.linear_1",
|
3576
|
+
"final_layer.adaLN_modulation.2": "norm_out.linear_2",
|
3577
|
+
"final_layer.linear": "proj_out",
|
3578
|
+
}
|
3579
|
+
|
3580
|
+
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
|
3581
|
+
"blocks.block": rename_transformer_blocks_,
|
3582
|
+
"logvar.0.freqs": remove_keys_,
|
3583
|
+
"logvar.0.phases": remove_keys_,
|
3584
|
+
"logvar.1.weight": remove_keys_,
|
3585
|
+
"pos_embedder.seq": remove_keys_,
|
3586
|
+
}
|
3587
|
+
|
3588
|
+
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
|
3589
|
+
"t_embedder.1": "time_embed.t_embedder",
|
3590
|
+
"t_embedding_norm": "time_embed.norm",
|
3591
|
+
"blocks": "transformer_blocks",
|
3592
|
+
"adaln_modulation_self_attn.1": "norm1.linear_1",
|
3593
|
+
"adaln_modulation_self_attn.2": "norm1.linear_2",
|
3594
|
+
"adaln_modulation_cross_attn.1": "norm2.linear_1",
|
3595
|
+
"adaln_modulation_cross_attn.2": "norm2.linear_2",
|
3596
|
+
"adaln_modulation_mlp.1": "norm3.linear_1",
|
3597
|
+
"adaln_modulation_mlp.2": "norm3.linear_2",
|
3598
|
+
"self_attn": "attn1",
|
3599
|
+
"cross_attn": "attn2",
|
3600
|
+
"q_proj": "to_q",
|
3601
|
+
"k_proj": "to_k",
|
3602
|
+
"v_proj": "to_v",
|
3603
|
+
"output_proj": "to_out.0",
|
3604
|
+
"q_norm": "norm_q",
|
3605
|
+
"k_norm": "norm_k",
|
3606
|
+
"mlp.layer1": "ff.net.0.proj",
|
3607
|
+
"mlp.layer2": "ff.net.2",
|
3608
|
+
"x_embedder.proj.1": "patch_embed.proj",
|
3609
|
+
"final_layer.adaln_modulation.1": "norm_out.linear_1",
|
3610
|
+
"final_layer.adaln_modulation.2": "norm_out.linear_2",
|
3611
|
+
"final_layer.linear": "proj_out",
|
3612
|
+
}
|
3613
|
+
|
3614
|
+
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
|
3615
|
+
"accum_video_sample_counter": remove_keys_,
|
3616
|
+
"accum_image_sample_counter": remove_keys_,
|
3617
|
+
"accum_iteration": remove_keys_,
|
3618
|
+
"accum_train_in_hours": remove_keys_,
|
3619
|
+
"pos_embedder.seq": remove_keys_,
|
3620
|
+
"pos_embedder.dim_spatial_range": remove_keys_,
|
3621
|
+
"pos_embedder.dim_temporal_range": remove_keys_,
|
3622
|
+
"_extra_state": remove_keys_,
|
3623
|
+
}
|
3624
|
+
|
3625
|
+
PREFIX_KEY = "net."
|
3626
|
+
if "net.blocks.block1.blocks.0.block.attn.to_q.0.weight" in checkpoint:
|
3627
|
+
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
|
3628
|
+
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
|
3629
|
+
else:
|
3630
|
+
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
|
3631
|
+
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
|
3632
|
+
|
3633
|
+
state_dict_keys = list(converted_state_dict.keys())
|
3634
|
+
for key in state_dict_keys:
|
3635
|
+
new_key = key[:]
|
3636
|
+
if new_key.startswith(PREFIX_KEY):
|
3637
|
+
new_key = new_key.removeprefix(PREFIX_KEY)
|
3638
|
+
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
3639
|
+
new_key = new_key.replace(replace_key, rename_key)
|
3640
|
+
converted_state_dict[new_key] = converted_state_dict.pop(key)
|
3641
|
+
|
3642
|
+
state_dict_keys = list(converted_state_dict.keys())
|
3643
|
+
for key in state_dict_keys:
|
3644
|
+
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
3645
|
+
if special_key not in key:
|
3646
|
+
continue
|
3647
|
+
handler_fn_inplace(key, converted_state_dict)
|
3648
|
+
|
3649
|
+
return converted_state_dict
|
@@ -17,12 +17,10 @@ from ..models.embeddings import (
|
|
17
17
|
ImageProjection,
|
18
18
|
MultiIPAdapterImageProjection,
|
19
19
|
)
|
20
|
-
from ..models.
|
21
|
-
from ..
|
22
|
-
|
23
|
-
|
24
|
-
logging,
|
25
|
-
)
|
20
|
+
from ..models.model_loading_utils import load_model_dict_into_meta
|
21
|
+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
22
|
+
from ..utils import is_accelerate_available, is_torch_version, logging
|
23
|
+
from ..utils.torch_utils import empty_device_cache
|
26
24
|
|
27
25
|
|
28
26
|
if is_accelerate_available():
|
@@ -84,13 +82,12 @@ class FluxTransformer2DLoadersMixin:
|
|
84
82
|
else:
|
85
83
|
device_map = {"": self.device}
|
86
84
|
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
85
|
+
empty_device_cache()
|
87
86
|
|
88
87
|
return image_projection
|
89
88
|
|
90
89
|
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
91
|
-
from ..models.
|
92
|
-
FluxIPAdapterJointAttnProcessor2_0,
|
93
|
-
)
|
90
|
+
from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
|
94
91
|
|
95
92
|
if low_cpu_mem_usage:
|
96
93
|
if is_accelerate_available():
|
@@ -122,7 +119,7 @@ class FluxTransformer2DLoadersMixin:
|
|
122
119
|
else:
|
123
120
|
cross_attention_dim = self.config.joint_attention_dim
|
124
121
|
hidden_size = self.inner_dim
|
125
|
-
attn_processor_class =
|
122
|
+
attn_processor_class = FluxIPAdapterAttnProcessor
|
126
123
|
num_image_text_embeds = []
|
127
124
|
for state_dict in state_dicts:
|
128
125
|
if "proj.weight" in state_dict["image_proj"]:
|
@@ -158,6 +155,8 @@ class FluxTransformer2DLoadersMixin:
|
|
158
155
|
|
159
156
|
key_id += 1
|
160
157
|
|
158
|
+
empty_device_cache()
|
159
|
+
|
161
160
|
return attn_procs
|
162
161
|
|
163
162
|
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
@@ -16,8 +16,10 @@ from typing import Dict
|
|
16
16
|
|
17
17
|
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
|
18
18
|
from ..models.embeddings import IPAdapterTimeImageProjection
|
19
|
-
from ..models.
|
19
|
+
from ..models.model_loading_utils import load_model_dict_into_meta
|
20
|
+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
20
21
|
from ..utils import is_accelerate_available, is_torch_version, logging
|
22
|
+
from ..utils.torch_utils import empty_device_cache
|
21
23
|
|
22
24
|
|
23
25
|
logger = logging.get_logger(__name__)
|
@@ -80,6 +82,8 @@ class SD3Transformer2DLoadersMixin:
|
|
80
82
|
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
|
81
83
|
)
|
82
84
|
|
85
|
+
empty_device_cache()
|
86
|
+
|
83
87
|
return attn_procs
|
84
88
|
|
85
89
|
def _convert_ip_adapter_image_proj_to_diffusers(
|
@@ -147,6 +151,7 @@ class SD3Transformer2DLoadersMixin:
|
|
147
151
|
else:
|
148
152
|
device_map = {"": self.device}
|
149
153
|
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
154
|
+
empty_device_cache()
|
150
155
|
|
151
156
|
return image_proj
|
152
157
|
|
diffusers/loaders/unet.py
CHANGED
@@ -30,7 +30,8 @@ from ..models.embeddings import (
|
|
30
30
|
IPAdapterPlusImageProjection,
|
31
31
|
MultiIPAdapterImageProjection,
|
32
32
|
)
|
33
|
-
from ..models.
|
33
|
+
from ..models.model_loading_utils import load_model_dict_into_meta
|
34
|
+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
|
34
35
|
from ..utils import (
|
35
36
|
USE_PEFT_BACKEND,
|
36
37
|
_get_model_file,
|
@@ -43,6 +44,7 @@ from ..utils import (
|
|
43
44
|
is_torch_version,
|
44
45
|
logging,
|
45
46
|
)
|
47
|
+
from ..utils.torch_utils import empty_device_cache
|
46
48
|
from .lora_base import _func_optionally_disable_offloading
|
47
49
|
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
|
48
50
|
from .utils import AttnProcsLayers
|
@@ -131,6 +133,8 @@ class UNet2DConditionLoadersMixin:
|
|
131
133
|
)
|
132
134
|
```
|
133
135
|
"""
|
136
|
+
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
|
137
|
+
|
134
138
|
cache_dir = kwargs.pop("cache_dir", None)
|
135
139
|
force_download = kwargs.pop("force_download", False)
|
136
140
|
proxies = kwargs.pop("proxies", None)
|
@@ -203,6 +207,7 @@ class UNet2DConditionLoadersMixin:
|
|
203
207
|
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
|
204
208
|
is_model_cpu_offload = False
|
205
209
|
is_sequential_cpu_offload = False
|
210
|
+
is_group_offload = False
|
206
211
|
|
207
212
|
if is_lora:
|
208
213
|
deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
|
@@ -211,7 +216,7 @@ class UNet2DConditionLoadersMixin:
|
|
211
216
|
if is_custom_diffusion:
|
212
217
|
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
|
213
218
|
elif is_lora:
|
214
|
-
is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
|
219
|
+
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
|
215
220
|
state_dict=state_dict,
|
216
221
|
unet_identifier_key=self.unet_name,
|
217
222
|
network_alphas=network_alphas,
|
@@ -230,7 +235,9 @@ class UNet2DConditionLoadersMixin:
|
|
230
235
|
|
231
236
|
# For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
|
232
237
|
if is_custom_diffusion and _pipeline is not None:
|
233
|
-
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(
|
238
|
+
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
|
239
|
+
_pipeline=_pipeline
|
240
|
+
)
|
234
241
|
|
235
242
|
# only custom diffusion needs to set attn processors
|
236
243
|
self.set_attn_processor(attn_processors)
|
@@ -241,6 +248,10 @@ class UNet2DConditionLoadersMixin:
|
|
241
248
|
_pipeline.enable_model_cpu_offload()
|
242
249
|
elif is_sequential_cpu_offload:
|
243
250
|
_pipeline.enable_sequential_cpu_offload()
|
251
|
+
elif is_group_offload:
|
252
|
+
for component in _pipeline.components.values():
|
253
|
+
if isinstance(component, torch.nn.Module):
|
254
|
+
_maybe_remove_and_reapply_group_offloading(component)
|
244
255
|
# Unsafe code />
|
245
256
|
|
246
257
|
def _process_custom_diffusion(self, state_dict):
|
@@ -307,6 +318,7 @@ class UNet2DConditionLoadersMixin:
|
|
307
318
|
|
308
319
|
is_model_cpu_offload = False
|
309
320
|
is_sequential_cpu_offload = False
|
321
|
+
is_group_offload = False
|
310
322
|
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
|
311
323
|
|
312
324
|
if len(state_dict_to_be_used) > 0:
|
@@ -356,7 +368,9 @@ class UNet2DConditionLoadersMixin:
|
|
356
368
|
|
357
369
|
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
358
370
|
# otherwise loading LoRA weights will lead to an error
|
359
|
-
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(
|
371
|
+
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
|
372
|
+
_pipeline
|
373
|
+
)
|
360
374
|
peft_kwargs = {}
|
361
375
|
if is_peft_version(">=", "0.13.1"):
|
362
376
|
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
@@ -389,7 +403,7 @@ class UNet2DConditionLoadersMixin:
|
|
389
403
|
if warn_msg:
|
390
404
|
logger.warning(warn_msg)
|
391
405
|
|
392
|
-
return is_model_cpu_offload, is_sequential_cpu_offload
|
406
|
+
return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload
|
393
407
|
|
394
408
|
@classmethod
|
395
409
|
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
|
@@ -741,6 +755,7 @@ class UNet2DConditionLoadersMixin:
|
|
741
755
|
else:
|
742
756
|
device_map = {"": self.device}
|
743
757
|
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
758
|
+
empty_device_cache()
|
744
759
|
|
745
760
|
return image_projection
|
746
761
|
|
@@ -838,6 +853,8 @@ class UNet2DConditionLoadersMixin:
|
|
838
853
|
|
839
854
|
key_id += 2
|
840
855
|
|
856
|
+
empty_device_cache()
|
857
|
+
|
841
858
|
return attn_procs
|
842
859
|
|
843
860
|
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|