diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
diffusers/loaders/__init__.py
@@ -78,12 +78,15 @@ if is_torch_available():
         "Lumina2LoraLoaderMixin",
         "WanLoraLoaderMixin",
         "HiDreamImageLoraLoaderMixin",
+        "SkyReelsV2LoraLoaderMixin",
+        "QwenImageLoraLoaderMixin",
     ]
     _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
     _import_structure["ip_adapter"] = [
         "IPAdapterMixin",
         "FluxIPAdapterMixin",
         "SD3IPAdapterMixin",
+        "ModularIPAdapterMixin",
     ]

     _import_structure["peft"] = ["PeftAdapterMixin"]
@@ -101,6 +104,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .ip_adapter import (
         FluxIPAdapterMixin,
         IPAdapterMixin,
+        ModularIPAdapterMixin,
         SD3IPAdapterMixin,
     )
     from .lora_pipeline import (
@@ -115,8 +119,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         LTXVideoLoraLoaderMixin,
         Lumina2LoraLoaderMixin,
         Mochi1LoraLoaderMixin,
+        QwenImageLoraLoaderMixin,
         SanaLoraLoaderMixin,
         SD3LoraLoaderMixin,
+        SkyReelsV2LoraLoaderMixin,
         StableDiffusionLoraLoaderMixin,
         StableDiffusionXLLoraLoaderMixin,
         WanLoraLoaderMixin,
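
The two new mixins are re-exported through `diffusers.loaders`, so pipelines that inherit them expose the standard `load_lora_weights` entry point. A minimal usage sketch; the LoRA repository id, weight filename, and adapter name are illustrative placeholders, not taken from this diff:

```python
# Hypothetical sketch: load a LoRA into a QwenImage pipeline, which gains
# QwenImageLoraLoaderMixin in 0.35. The LoRA repo and names are placeholders.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.load_lora_weights(
    "some-user/qwen-image-lora",  # placeholder repository
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="my-style",
)
pipe.set_adapters(["my-style"], adapter_weights=[0.8])
```
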
diffusers/loaders/ip_adapter.py
@@ -40,8 +40,6 @@ if is_transformers_available():
     from ..models.attention_processor import (
         AttnProcessor,
         AttnProcessor2_0,
-        FluxAttnProcessor2_0,
-        FluxIPAdapterJointAttnProcessor2_0,
         IPAdapterAttnProcessor,
         IPAdapterAttnProcessor2_0,
         IPAdapterXFormersAttnProcessor,
@@ -354,6 +352,256 @@ class IPAdapterMixin:
         self.unet.set_attn_processor(attn_procs)


+class ModularIPAdapterMixin:
+    """Mixin for handling IP Adapters."""
+
+    @validate_hf_hub_args
+    def load_ip_adapter(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
+        subfolder: Union[str, List[str]],
+        weight_name: Union[str, List[str]],
+        **kwargs,
+    ):
+        """
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
+                Can be either:
+
+                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                  the Hub.
+                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                  with [`ModelMixin.save_pretrained`].
+                - A [torch state
+                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+            subfolder (`str` or `List[str]`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+                list is passed, it should have the same length as `weight_name`.
+            weight_name (`str` or `List[str]`):
+                The name of the weight file to load. If a list is passed, it should have the same length as
+                `subfolder`.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+        """
+
+        # handle the list inputs for multiple IP Adapters
+        if not isinstance(weight_name, list):
+            weight_name = [weight_name]
+
+        if not isinstance(pretrained_model_name_or_path_or_dict, list):
+            pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
+        if len(pretrained_model_name_or_path_or_dict) == 1:
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
+
+        if not isinstance(subfolder, list):
+            subfolder = [subfolder]
+        if len(subfolder) == 1:
+            subfolder = subfolder * len(weight_name)
+
+        if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
+            raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
+
+        if len(weight_name) != len(subfolder):
+            raise ValueError("`weight_name` and `subfolder` must have the same length.")
+
+        # Load the main state dict first.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+        state_dicts = []
+        for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
+            pretrained_model_name_or_path_or_dict, weight_name, subfolder
+        ):
+            if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                if weight_name.endswith(".safetensors"):
+                    state_dict = {"image_proj": {}, "ip_adapter": {}}
+                    with safe_open(model_file, framework="pt", device="cpu") as f:
+                        for key in f.keys():
+                            if key.startswith("image_proj."):
+                                state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+                            elif key.startswith("ip_adapter."):
+                                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+                else:
+                    state_dict = load_state_dict(model_file)
+            else:
+                state_dict = pretrained_model_name_or_path_or_dict
+
+            keys = list(state_dict.keys())
+            if "image_proj" not in keys and "ip_adapter" not in keys:
+                raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
+
+            state_dicts.append(state_dict)
+
+        unet_name = getattr(self, "unet_name", "unet")
+        unet = getattr(self, unet_name)
+        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+
+        extra_loras = unet._load_ip_adapter_loras(state_dicts)
+        if extra_loras != {}:
+            if not USE_PEFT_BACKEND:
+                logger.warning("PEFT backend is required to load these weights.")
+            else:
+                # apply the IP Adapter Face ID LoRA weights
+                peft_config = getattr(unet, "peft_config", {})
+                for k, lora in extra_loras.items():
+                    if f"faceid_{k}" not in peft_config:
+                        self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
+                        self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
+
+    def set_ip_adapter_scale(self, scale):
+        """
+        Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
+        granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
+
+        Example:
+
+        ```py
+        # To use original IP-Adapter
+        scale = 1.0
+        pipeline.set_ip_adapter_scale(scale)
+
+        # To use style block only
+        scale = {
+            "up": {"block_0": [0.0, 1.0, 0.0]},
+        }
+        pipeline.set_ip_adapter_scale(scale)
+
+        # To use style+layout blocks
+        scale = {
+            "down": {"block_2": [0.0, 1.0]},
+            "up": {"block_0": [0.0, 1.0, 0.0]},
+        }
+        pipeline.set_ip_adapter_scale(scale)
+
+        # To use style and layout from 2 reference images
+        scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
+        pipeline.set_ip_adapter_scale(scales)
+        ```
+        """
+        unet_name = getattr(self, "unet_name", "unet")
+        unet = getattr(self, unet_name)
+        if not isinstance(scale, list):
+            scale = [scale]
+        scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
+
+        for attn_name, attn_processor in unet.attn_processors.items():
+            if isinstance(
+                attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+            ):
+                if len(scale_configs) != len(attn_processor.scale):
+                    raise ValueError(
+                        f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
+                    )
+                elif len(scale_configs) == 1:
+                    scale_configs = scale_configs * len(attn_processor.scale)
+                for i, scale_config in enumerate(scale_configs):
+                    if isinstance(scale_config, dict):
+                        for k, s in scale_config.items():
+                            if attn_name.startswith(k):
+                                attn_processor.scale[i] = s
+                    else:
+                        attn_processor.scale[i] = scale_config
+
+    def unload_ip_adapter(self):
+        """
+        Unloads the IP Adapter weights
+
+        Examples:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+        >>> pipeline.unload_ip_adapter()
+        >>> ...
+        ```
+        """
+
+        # remove hidden encoder
+        if self.unet is None:
+            return
+
+        self.unet.encoder_hid_proj = None
+        self.unet.config.encoder_hid_dim_type = None
+
+        # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
+        if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
+            self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
+            self.unet.text_encoder_hid_proj = None
+            self.unet.config.encoder_hid_dim_type = "text_proj"
+
+        # restore original Unet attention processors layers
+        attn_procs = {}
+        for name, value in self.unet.attn_processors.items():
+            attn_processor_class = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
+            )
+            attn_procs[name] = (
+                attn_processor_class
+                if isinstance(
+                    value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+                )
+                else value.__class__()
+            )
+        self.unet.set_attn_processor(attn_procs)
+
+
 class FluxIPAdapterMixin:
     """Mixin for handling Flux IP Adapters."""

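Because `load_ip_adapter` broadcasts a single repository id or subfolder across the list of `weight_name`s, stacking two IP-Adapters takes one call. A sketch assuming `pipe` inherits `ModularIPAdapterMixin`; the repository, subfolder, and file names are placeholders:

```python
# Hypothetical sketch: one repo id and one subfolder are broadcast across two
# weight files; scales are then set per adapter (a float and a block dict).
pipe.load_ip_adapter(
    "h94/IP-Adapter",  # broadcast to both weight files
    subfolder="sdxl_models",  # broadcast as well
    weight_name=["ip-adapter_sdxl.safetensors", "ip-adapter-plus_sdxl_vit-h.safetensors"],
)
pipe.set_ip_adapter_scale([0.6, {"up": {"block_0": [0.0, 1.0, 0.0]}}])
```
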
@@ -617,6 +865,9 @@ class FluxIPAdapterMixin:
        >>> ...
        ```
        """
+        # TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
+        from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor
+
        # remove CLIP image encoder
        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
            self.image_encoder = None
@@ -636,9 +887,9 @@ class FluxIPAdapterMixin:
         # restore original Transformer attention processors layers
         attn_procs = {}
         for name, value in self.transformer.attn_processors.items():
-            attn_processor_class = FluxAttnProcessor2_0()
+            attn_processor_class = FluxAttnProcessor()
             attn_procs[name] = (
-                attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
+                attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
             )
         self.transformer.set_attn_processor(attn_procs)

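For downstream code that imported the removed `*2_0` processor names, the rename visible in this hunk maps roughly as follows (a migration sketch, not part of the diff itself):

```python
# 0.34.x (removed from the diffusers.models.attention_processor imports above):
# from diffusers.models.attention_processor import (
#     FluxAttnProcessor2_0,
#     FluxIPAdapterJointAttnProcessor2_0,
# )

# 0.35.x: the replacement classes live in the Flux transformer module.
from diffusers.models.transformers.transformer_flux import (
    FluxAttnProcessor,
    FluxIPAdapterAttnProcessor,
)
```
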
diffusers/loaders/lora_base.py
@@ -330,6 +330,8 @@ def _load_lora_into_text_encoder(
     hotswap: bool = False,
     metadata=None,
 ):
+    from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
+
     if not USE_PEFT_BACKEND:
         raise ValueError("PEFT backend is required for this method.")

@@ -391,7 +393,9 @@ def _load_lora_into_text_encoder(
         adapter_name = get_adapter_name(text_encoder)

     # <Unsafe code
-    is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+    is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
+        _pipeline
+    )
     # inject LoRA layers and load the state dict
     # in transformers we automatically check whether the adapter name is already in use or not
     text_encoder.load_adapter(
@@ -410,6 +414,10 @@ def _load_lora_into_text_encoder(
        _pipeline.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        _pipeline.enable_sequential_cpu_offload()
+    elif is_group_offload:
+        for component in _pipeline.components.values():
+            if isinstance(component, torch.nn.Module):
+                _maybe_remove_and_reapply_group_offloading(component)
    # Unsafe code />

    if prefix is not None and not state_dict:
@@ -433,30 +441,38 @@ def _func_optionally_disable_offloading(_pipeline):

     Returns:
         tuple:
-            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
     """
+    from ..hooks.group_offloading import _is_group_offload_enabled
+
     is_model_cpu_offload = False
     is_sequential_cpu_offload = False
+    is_group_offload = False

     if _pipeline is not None and _pipeline.hf_device_map is None:
         for _, component in _pipeline.components.items():
-            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                if not is_model_cpu_offload:
-                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                if not is_sequential_cpu_offload:
-                    is_sequential_cpu_offload = (
-                        isinstance(component._hf_hook, AlignDevicesHook)
-                        or hasattr(component._hf_hook, "hooks")
-                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                    )
+            if not isinstance(component, nn.Module):
+                continue
+            is_group_offload = is_group_offload or _is_group_offload_enabled(component)
+            if not hasattr(component, "_hf_hook"):
+                continue
+            is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
+            is_sequential_cpu_offload = is_sequential_cpu_offload or (
+                isinstance(component._hf_hook, AlignDevicesHook)
+                or hasattr(component._hf_hook, "hooks")
+                and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+            )

-                logger.info(
-                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                )
-                if is_sequential_cpu_offload or is_model_cpu_offload:
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+        if is_sequential_cpu_offload or is_model_cpu_offload:
+            logger.info(
+                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+            )
+            for _, component in _pipeline.components.items():
+                if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
+                    continue
+                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

-    return (is_model_cpu_offload, is_sequential_cpu_offload)
+    return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)


 class LoraBaseMixin:
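
Callers of this private helper must now unpack three flags; the third reports whether group offloading is enabled on any pipeline component. A sketch of the new contract (private API, subject to change; `pipe` is assumed to be an already-loaded `DiffusionPipeline`):

```python
# Hypothetical sketch of the 0.35 return contract: three flags instead of two.
from diffusers.loaders.lora_base import _func_optionally_disable_offloading

is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = (
    _func_optionally_disable_offloading(_pipeline=pipe)
)
# Group-offloading hooks are not removed by this helper; per the hunks above,
# they are removed and reapplied around the LoRA load by the calling code.
```
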
@@ -738,7 +754,11 @@ class LoraBaseMixin:
         # Decompose weights into weights for denoiser and text encoders.
         _component_adapter_weights = {}
         for component in self._lora_loadable_modules:
-            model = getattr(self, component)
+            model = getattr(self, component, None)
+            # To guard for cases like Wan. In Wan2.1 and WanVace, we have a single denoiser.
+            # Whereas in Wan 2.2, we have two denoisers.
+            if model is None:
+                continue

             for adapter_name, weights in zip(adapter_names, adapter_weights):
                 if isinstance(weights, dict):
@@ -921,6 +941,27 @@ class LoraBaseMixin:
         Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
         you want to load multiple adapters and free some GPU memory.

+        After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters
+        can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to
+        GPU before using those LoRA adapters for inference.
+
+        ```python
+        >>> pipe.load_lora_weights(path_1, adapter_name="adapter-1")
+        >>> pipe.load_lora_weights(path_2, adapter_name="adapter-2")
+        >>> pipe.set_adapters("adapter-1")
+        >>> image_1 = pipe(**kwargs)
+        >>> # switch to adapter-2, offload adapter-1
+        >>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cpu")
+        >>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
+        >>> pipe.set_adapters("adapter-2")
+        >>> image_2 = pipe(**kwargs)
+        >>> # switch back to adapter-1, offload adapter-2
+        >>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cpu")
+        >>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
+        >>> pipe.set_adapters("adapter-1")
+        >>> ...
+        ```
+
         Args:
             adapter_names (`List[str]`):
                 List of adapters to send device to.
@@ -936,6 +977,10 @@ class LoraBaseMixin:
             for module in model.modules():
                 if isinstance(module, BaseTunerLayer):
                     for adapter_name in adapter_names:
+                        if adapter_name not in module.lora_A:
+                            # it is sufficient to check lora_A
+                            continue
+
                         module.lora_A[adapter_name].to(device)
                         module.lora_B[adapter_name].to(device)
                         # this is a param, not a module, so device placement is not in-place -> re-assign
@@ -1022,15 +1067,3 @@ class LoraBaseMixin:
     @classmethod
     def _optionally_disable_offloading(cls, _pipeline):
         return _func_optionally_disable_offloading(_pipeline=_pipeline)
-
-    @classmethod
-    def _fetch_state_dict(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
-        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
-        return _fetch_state_dict(*args, **kwargs)
-
-    @classmethod
-    def _best_guess_weight_name(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
-        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
-        return _best_guess_weight_name(*args, **kwargs)
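
Code that still called the removed classmethod shims can import the module-level helpers directly, as the 0.34.x deprecation messages already advised (both are private helpers and may change without notice):

```python
# Migration sketch for the two classmethods deleted in this hunk.
from diffusers.loaders.lora_base import _best_guess_weight_name, _fetch_state_dict
```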