diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
diffusers/loaders/peft.py CHANGED
@@ -61,6 +61,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
     "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
     "WanVACETransformer3DModel": lambda model_cls, weights: weights,
     "ChromaTransformer2DModel": lambda model_cls, weights: weights,
+    "QwenImageTransformer2DModel": lambda model_cls, weights: weights,
 }


@@ -163,6 +164,8 @@ class PeftAdapterMixin:
         from peft import inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer

+        from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -243,20 +246,29 @@ class PeftAdapterMixin:
             k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
         }

-        # create LoraConfig
-        lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
-
         # adapter_name
         if adapter_name is None:
             adapter_name = get_adapter_name(self)

+        # create LoraConfig
+        lora_config = _create_lora_config(
+            state_dict,
+            network_alphas,
+            metadata,
+            rank,
+            model_state_dict=self.state_dict(),
+            adapter_name=adapter_name,
+        )
+
         # <Unsafe code
         # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
         # Now we remove any existing hooks to `_pipeline`.

         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
         # otherwise loading LoRA weights will lead to an error.
-        is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+        is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+            _pipeline
+        )
         peft_kwargs = {}
         if is_peft_version(">=", "0.13.1"):
             peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -308,7 +320,9 @@ class PeftAdapterMixin:
             # it to None
             incompatible_keys = None
         else:
-            inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+            inject_adapter_in_model(
+                lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
+            )
             incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)

         if self._prepare_lora_hotswap_kwargs is not None:
@@ -347,6 +361,10 @@ class PeftAdapterMixin:
                 _pipeline.enable_model_cpu_offload()
             elif is_sequential_cpu_offload:
                 _pipeline.enable_sequential_cpu_offload()
+            elif is_group_offload:
+                for component in _pipeline.components.values():
+                    if isinstance(component, torch.nn.Module):
+                        _maybe_remove_and_reapply_group_offloading(component)
             # Unsafe code />

         if prefix is not None and not state_dict:
@@ -681,11 +699,16 @@ class PeftAdapterMixin:
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for `unload_lora()`.")

+        from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
         from ..utils import recurse_remove_peft_layers

         recurse_remove_peft_layers(self)
         if hasattr(self, "peft_config"):
             del self.peft_config
+        if hasattr(self, "_hf_peft_config_loaded"):
+            self._hf_peft_config_loaded = None
+
+        _maybe_remove_and_reapply_group_offloading(self)

     def disable_lora(self):
         """
diffusers/loaders/single_file_model.py CHANGED
@@ -23,7 +23,8 @@ from typing_extensions import Self

 from .. import __version__
 from ..quantizers import DiffusersAutoQuantizer
-from ..utils import deprecate, is_accelerate_available, logging
+from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
+from ..utils.torch_utils import empty_device_cache
 from .single_file_utils import (
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,
@@ -31,6 +32,7 @@ from .single_file_utils import (
     convert_autoencoder_dc_checkpoint_to_diffusers,
     convert_chroma_transformer_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
+    convert_cosmos_transformer_checkpoint_to_diffusers,
     convert_flux_transformer_checkpoint_to_diffusers,
     convert_hidream_transformer_to_diffusers,
     convert_hunyuan_video_transformer_to_diffusers,
@@ -60,8 +62,12 @@ logger = logging.get_logger(__name__)
 if is_accelerate_available():
     from accelerate import dispatch_model, init_empty_weights

-    from ..models.modeling_utils import load_model_dict_into_meta
+    from ..models.model_loading_utils import load_model_dict_into_meta

+if is_torch_version(">=", "1.9.0") and is_accelerate_available():
+    _LOW_CPU_MEM_USAGE_DEFAULT = True
+else:
+    _LOW_CPU_MEM_USAGE_DEFAULT = False

 SINGLE_FILE_LOADABLE_CLASSES = {
     "StableCascadeUNet": {
@@ -135,6 +141,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
         "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "WanVACETransformer3DModel": {
+        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
     "AutoencoderKLWan": {
         "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
         "default_subfolder": "vae",
@@ -143,9 +153,21 @@ SINGLE_FILE_LOADABLE_CLASSES = {
         "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "CosmosTransformer3DModel": {
+        "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
+    "QwenImageTransformer2DModel": {
+        "checkpoint_mapping_fn": lambda x: x,
+        "default_subfolder": "transformer",
+    },
 }


+def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
+    return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
+
+
 def _get_single_file_loadable_mapping_class(cls):
     diffusers_module = importlib.import_module(__name__.split(".")[0])
     for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
@@ -218,6 +240,11 @@ class FromOriginalModelMixin:
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 and
+                is_accelerate_available() else `False`): Speed up model loading only loading the pretrained weights and
+                not initializing the weights. This also tries to not use more than 1x model size in CPU memory
+                (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using
+                an older version of PyTorch, setting this argument to `True` will raise an error.
             disable_mmap ('bool', *optional*, defaults to 'False'):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -267,6 +294,7 @@ class FromOriginalModelMixin:
         config_revision = kwargs.pop("config_revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
         quantization_config = kwargs.pop("quantization_config", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

@@ -371,19 +399,23 @@ class FromOriginalModelMixin:
         model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
         diffusers_model_config.update(model_kwargs)

+        ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+        with ctx():
+            model = cls.from_config(diffusers_model_config)
+
         checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
-        diffusers_format_checkpoint = checkpoint_mapping_fn(
-            config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
-        )
+
+        if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
+            diffusers_format_checkpoint = checkpoint_mapping_fn(
+                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+            )
+        else:
+            diffusers_format_checkpoint = checkpoint
+
         if not diffusers_format_checkpoint:
             raise SingleFileComponentError(
                 f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
             )
-
-        ctx = init_empty_weights if is_accelerate_available() else nullcontext
-        with ctx():
-            model = cls.from_config(diffusers_model_config)
-
         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
             (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
@@ -405,7 +437,7 @@ class FromOriginalModelMixin:
         )

         device_map = None
-        if is_accelerate_available():
+        if low_cpu_mem_usage:
             param_device = torch.device(device) if device else torch.device("cpu")
             empty_state_dict = model.state_dict()
             unexpected_keys = [
@@ -421,6 +453,7 @@
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 unexpected_keys=unexpected_keys,
             )
+            empty_device_cache()
         else:
             _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
diffusers/loaders/single_file_utils.py CHANGED
@@ -46,6 +46,7 @@ from ..utils import (
 )
 from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
 from ..utils.hub_utils import _get_model_file
+from ..utils.torch_utils import empty_device_cache


 if is_transformers_available():
@@ -54,11 +55,12 @@ if is_transformers_available():
 if is_accelerate_available():
     from accelerate import init_empty_weights

-    from ..models.modeling_utils import load_model_dict_into_meta
+    from ..models.model_loading_utils import load_model_dict_into_meta

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 CHECKPOINT_KEY_NAMES = {
+    "v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
     "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
     "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
     "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
@@ -126,7 +128,18 @@ CHECKPOINT_KEY_NAMES = {
     ],
     "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
     "wan_vae": "decoder.middle.0.residual.0.gamma",
+    "wan_vace": "vace_blocks.0.after_proj.bias",
     "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
+    "cosmos-1.0": [
+        "net.x_embedder.proj.1.weight",
+        "net.blocks.block1.blocks.0.block.attn.to_q.0.weight",
+        "net.extra_pos_embedder.pos_emb_h",
+    ],
+    "cosmos-2.0": [
+        "net.x_embedder.proj.1.weight",
+        "net.blocks.0.self_attn.q_proj.weight",
+        "net.pos_embedder.dim_spatial_range",
+    ],
 }

 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -192,7 +205,17 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
     "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
     "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
+    "wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
+    "wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
     "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
+    "cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
+    "cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
+    "cosmos-1.0-v2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"},
+    "cosmos-1.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Video2World"},
+    "cosmos-2.0-t2i-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Text2Image"},
+    "cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
+    "cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
+    "cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
 }

 # Use to configure model sample size when original config is provided
@@ -698,17 +721,44 @@ def infer_diffusers_model_type(checkpoint):
         else:
             target_key = "patch_embedding.weight"

-        if checkpoint[target_key].shape[0] == 1536:
+        if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint:
+            if checkpoint[target_key].shape[0] == 1536:
+                model_type = "wan-vace-1.3B"
+            elif checkpoint[target_key].shape[0] == 5120:
+                model_type = "wan-vace-14B"
+
+        elif checkpoint[target_key].shape[0] == 1536:
             model_type = "wan-t2v-1.3B"
         elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
             model_type = "wan-t2v-14B"
         else:
             model_type = "wan-i2v-14B"
+
     elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
         # All Wan models use the same VAE so we can use the same default model repo to fetch the config
         model_type = "wan-t2v-14B"
+
     elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
         model_type = "hidream"
+
+    elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-1.0"]):
+        x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-1.0"][0]].shape
+        if x_embedder_shape[1] == 68:
+            model_type = "cosmos-1.0-t2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-t2w-14B"
+        elif x_embedder_shape[1] == 72:
+            model_type = "cosmos-1.0-v2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-v2w-14B"
+        else:
+            raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 1.0 model.")
+
+    elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-2.0"]):
+        x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-2.0"][0]].shape
+        if x_embedder_shape[1] == 68:
+            model_type = "cosmos-2.0-t2i-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-t2i-14B"
+        elif x_embedder_shape[1] == 72:
+            model_type = "cosmos-2.0-v2w-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-v2w-14B"
+        else:
+            raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
+
     else:
         model_type = "v1"

@@ -1641,6 +1691,7 @@ def create_diffusers_clip_model_from_ldm(

     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        empty_device_cache()
     else:
         model.load_state_dict(diffusers_format_checkpoint, strict=False)

@@ -2100,6 +2151,7 @@ def create_diffusers_t5_model_from_checkpoint(

     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        empty_device_cache()
     else:
         model.load_state_dict(diffusers_format_checkpoint)

@@ -3093,6 +3145,9 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
         "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
         "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
         "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+        # For the VACE model
+        "before_proj": "proj_in",
+        "after_proj": "proj_out",
     }

     for key in list(checkpoint.keys()):
@@ -3479,3 +3534,116 @@ def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")

     return converted_state_dict
+
+
+def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+    def remove_keys_(key: str, state_dict):
+        state_dict.pop(key)
+
+    def rename_transformer_blocks_(key: str, state_dict):
+        block_index = int(key.split(".")[1].removeprefix("block"))
+        new_key = key
+        old_prefix = f"blocks.block{block_index}"
+        new_prefix = f"transformer_blocks.{block_index}"
+        new_key = new_prefix + new_key.removeprefix(old_prefix)
+        state_dict[new_key] = state_dict.pop(key)
+
+    TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
+        "t_embedder.1": "time_embed.t_embedder",
+        "affline_norm": "time_embed.norm",
+        ".blocks.0.block.attn": ".attn1",
+        ".blocks.1.block.attn": ".attn2",
+        ".blocks.2.block": ".ff",
+        ".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
+        ".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
+        ".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
+        ".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
+        ".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
+        ".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
+        "to_q.0": "to_q",
+        "to_q.1": "norm_q",
+        "to_k.0": "to_k",
+        "to_k.1": "norm_k",
+        "to_v.0": "to_v",
+        "layer1": "net.0.proj",
+        "layer2": "net.2",
+        "proj.1": "proj",
+        "x_embedder": "patch_embed",
+        "extra_pos_embedder": "learnable_pos_embed",
+        "final_layer.adaLN_modulation.1": "norm_out.linear_1",
+        "final_layer.adaLN_modulation.2": "norm_out.linear_2",
+        "final_layer.linear": "proj_out",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
+        "blocks.block": rename_transformer_blocks_,
+        "logvar.0.freqs": remove_keys_,
+        "logvar.0.phases": remove_keys_,
+        "logvar.1.weight": remove_keys_,
+        "pos_embedder.seq": remove_keys_,
+    }
+
+    TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
+        "t_embedder.1": "time_embed.t_embedder",
+        "t_embedding_norm": "time_embed.norm",
+        "blocks": "transformer_blocks",
+        "adaln_modulation_self_attn.1": "norm1.linear_1",
+        "adaln_modulation_self_attn.2": "norm1.linear_2",
+        "adaln_modulation_cross_attn.1": "norm2.linear_1",
+        "adaln_modulation_cross_attn.2": "norm2.linear_2",
+        "adaln_modulation_mlp.1": "norm3.linear_1",
+        "adaln_modulation_mlp.2": "norm3.linear_2",
+        "self_attn": "attn1",
+        "cross_attn": "attn2",
+        "q_proj": "to_q",
+        "k_proj": "to_k",
+        "v_proj": "to_v",
+        "output_proj": "to_out.0",
+        "q_norm": "norm_q",
+        "k_norm": "norm_k",
+        "mlp.layer1": "ff.net.0.proj",
+        "mlp.layer2": "ff.net.2",
+        "x_embedder.proj.1": "patch_embed.proj",
+        "final_layer.adaln_modulation.1": "norm_out.linear_1",
+        "final_layer.adaln_modulation.2": "norm_out.linear_2",
+        "final_layer.linear": "proj_out",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
+        "accum_video_sample_counter": remove_keys_,
+        "accum_image_sample_counter": remove_keys_,
+        "accum_iteration": remove_keys_,
+        "accum_train_in_hours": remove_keys_,
+        "pos_embedder.seq": remove_keys_,
+        "pos_embedder.dim_spatial_range": remove_keys_,
+        "pos_embedder.dim_temporal_range": remove_keys_,
+        "_extra_state": remove_keys_,
+    }
+
+    PREFIX_KEY = "net."
+    if "net.blocks.block1.blocks.0.block.attn.to_q.0.weight" in checkpoint:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
+    else:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+
+    state_dict_keys = list(converted_state_dict.keys())
+    for key in state_dict_keys:
+        new_key = key[:]
+        if new_key.startswith(PREFIX_KEY):
+            new_key = new_key.removeprefix(PREFIX_KEY)
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    state_dict_keys = list(converted_state_dict.keys())
+    for key in state_dict_keys:
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
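The detection added above is key- and shape-based: `infer_diffusers_model_type` first tests which marker keys are present (`wan_vace`, `cosmos-1.0`, `cosmos-2.0`, ...), then picks the exact variant from tensor shapes; for Cosmos, a second dimension of 68 on `net.x_embedder.proj.1.weight` selects Text2World/Text2Image, 72 selects Video2World, and the first dimension separates the 7B/14B (1.0) or 2B/14B (2.0) sizes. The user-visible effect is that single-file checkpoints for these models now resolve their default config repos automatically; a sketch with a placeholder local path:

    from diffusers import WanVACETransformer3DModel

    transformer = WanVACETransformer3DModel.from_single_file(
        "./Wan2.1-VACE-1.3B.safetensors"  # placeholder local checkpoint
    )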
diffusers/loaders/transformer_flux.py CHANGED
@@ -17,12 +17,10 @@ from ..models.embeddings import (
     ImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
-from ..utils import (
-    is_accelerate_available,
-    is_torch_version,
-    logging,
-)
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
+from ..utils import is_accelerate_available, is_torch_version, logging
+from ..utils.torch_utils import empty_device_cache


 if is_accelerate_available():
@@ -84,13 +82,12 @@ class FluxTransformer2DLoadersMixin:
         else:
             device_map = {"": self.device}
             load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
+            empty_device_cache()

         return image_projection

     def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
-        from ..models.attention_processor import (
-            FluxIPAdapterJointAttnProcessor2_0,
-        )
+        from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor

         if low_cpu_mem_usage:
             if is_accelerate_available():
@@ -122,7 +119,7 @@ class FluxTransformer2DLoadersMixin:
         else:
             cross_attention_dim = self.config.joint_attention_dim
             hidden_size = self.inner_dim
-            attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+            attn_processor_class = FluxIPAdapterAttnProcessor
         num_image_text_embeds = []
         for state_dict in state_dicts:
             if "proj.weight" in state_dict["image_proj"]:
@@ -158,6 +155,8 @@ class FluxTransformer2DLoadersMixin:

             key_id += 1

+        empty_device_cache()
+
         return attn_procs

     def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
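Note for code that imported the Flux IP-Adapter processor directly: the class moved out of `attention_processor` and was renamed, so under 0.35.x the import becomes:

    # replaces: from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
    from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor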
diffusers/loaders/transformer_sd3.py CHANGED
@@ -16,8 +16,10 @@ from typing import Dict

 from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
 from ..models.embeddings import IPAdapterTimeImageProjection
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from ..utils import is_accelerate_available, is_torch_version, logging
+from ..utils.torch_utils import empty_device_cache


 logger = logging.get_logger(__name__)
@@ -80,6 +82,8 @@ class SD3Transformer2DLoadersMixin:
                 attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
             )

+        empty_device_cache()
+
         return attn_procs

     def _convert_ip_adapter_image_proj_to_diffusers(
@@ -147,6 +151,7 @@ class SD3Transformer2DLoadersMixin:
         else:
             device_map = {"": self.device}
             load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
+            empty_device_cache()

         return image_proj

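`empty_device_cache` (imported above from `diffusers.utils.torch_utils`) is now called after each meta-device load to release allocator cache inflated by the weight streaming. If you drive `load_model_dict_into_meta` yourself, a rough equivalent by hand, assuming the same helper:

    from diffusers.utils.torch_utils import empty_device_cache

    empty_device_cache()  # roughly torch.cuda.empty_cache() on CUDA, dispatched per accelerator backend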
diffusers/loaders/unet.py CHANGED
@@ -30,7 +30,8 @@ from ..models.embeddings import (
     IPAdapterPlusImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
@@ -43,6 +44,7 @@ from ..utils import (
     is_torch_version,
     logging,
 )
+from ..utils.torch_utils import empty_device_cache
 from .lora_base import _func_optionally_disable_offloading
 from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
 from .utils import AttnProcsLayers
@@ -131,6 +133,8 @@ class UNet2DConditionLoadersMixin:
         )
         ```
         """
+        from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -203,6 +207,7 @@ class UNet2DConditionLoadersMixin:
         is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
+        is_group_offload = False

         if is_lora:
             deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
@@ -211,7 +216,7 @@ class UNet2DConditionLoadersMixin:
         if is_custom_diffusion:
             attn_processors = self._process_custom_diffusion(state_dict=state_dict)
         elif is_lora:
-            is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
                 state_dict=state_dict,
                 unet_identifier_key=self.unet_name,
                 network_alphas=network_alphas,
@@ -230,7 +235,9 @@ class UNet2DConditionLoadersMixin:

         # For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
         if is_custom_diffusion and _pipeline is not None:
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+                _pipeline=_pipeline
+            )

         # only custom diffusion needs to set attn processors
         self.set_attn_processor(attn_processors)
@@ -241,6 +248,10 @@ class UNet2DConditionLoadersMixin:
             _pipeline.enable_model_cpu_offload()
         elif is_sequential_cpu_offload:
             _pipeline.enable_sequential_cpu_offload()
+        elif is_group_offload:
+            for component in _pipeline.components.values():
+                if isinstance(component, torch.nn.Module):
+                    _maybe_remove_and_reapply_group_offloading(component)
         # Unsafe code />

     def _process_custom_diffusion(self, state_dict):
@@ -307,6 +318,7 @@ class UNet2DConditionLoadersMixin:

         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
+        is_group_offload = False
         state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict

         if len(state_dict_to_be_used) > 0:
@@ -356,7 +368,9 @@ class UNet2DConditionLoadersMixin:

             # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
             # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+                _pipeline
+            )
             peft_kwargs = {}
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -389,7 +403,7 @@ class UNet2DConditionLoadersMixin:
             if warn_msg:
                 logger.warning(warn_msg)

-        return is_model_cpu_offload, is_sequential_cpu_offload
+        return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload

     @classmethod
     # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
@@ -741,6 +755,7 @@ class UNet2DConditionLoadersMixin:
         else:
             device_map = {"": self.device}
             load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
+            empty_device_cache()

         return image_projection

@@ -838,6 +853,8 @@ class UNet2DConditionLoadersMixin:

             key_id += 2

+        empty_device_cache()
+
         return attn_procs

     def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
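As in `peft.py`, the UNet loaders now carry a third flag, `is_group_offload`, through `_optionally_disable_offloading` and `_process_lora`, and reapply group offloading per component after the attention processors are set. A user-level sketch (the repo and weight names follow the documented IP-Adapter layout; the offload settings are illustrative):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    pipe.unet.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
    )
    # With 0.35.x, the loaders shown above detect the group-offload hooks and
    # restore them after the adapter weights are in place.
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")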