diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +74 -28
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
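This release adds several new pipelines (Allegro, HunyuanVideo, LTX, Mochi, Sana, and the Flux Control/Fill/Redux variants) and two new quantization backends (GGUF and torchao). A quick, non-authoritative way to check which of the new entry points your environment actually exposes is to probe the top-level module; the class names below are assumptions inferred from the new pipeline files listed above, not confirmed against the release.

```python
# Probe the installed diffusers for some of the 0.32.0 additions listed above.
# Class names are inferred from the new pipeline modules and may not all be exported.
from importlib.metadata import version

import diffusers

print(version("diffusers"))  # expected: 0.32.0
for name in ("SanaPipeline", "MochiPipeline", "HunyuanVideoPipeline", "LTXPipeline", "FluxFillPipeline"):
    print(f"{name}: {hasattr(diffusers, name)}")
```

The hunks below are from diffusers/loaders/lora_pipeline.py (item 10 in the list above).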
@@ -11,6 +11,7 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+
  import os
  from typing import Callable, Dict, List, Optional, Union

@@ -21,7 +22,6 @@ from ..utils import (
  USE_PEFT_BACKEND,
  convert_state_dict_to_diffusers,
  convert_state_dict_to_peft,
- convert_unet_state_dict_to_peft,
  deprecate,
  get_adapter_name,
  get_peft_kwargs,
@@ -33,8 +33,9 @@ from ..utils import (
  logging,
  scale_lora_layers,
  )
- from .lora_base import LoraBaseMixin
+ from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa
  from .lora_conversion_utils import (
+ _convert_bfl_flux_control_lora_to_diffusers,
  _convert_kohya_flux_lora_to_diffusers,
  _convert_non_diffusers_lora_to_diffusers,
  _convert_xlabs_flux_lora_to_diffusers,
@@ -62,8 +63,7 @@ TEXT_ENCODER_NAME = "text_encoder"
  UNET_NAME = "unet"
  TRANSFORMER_NAME = "transformer"

- LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
- LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+ _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}


  class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
@@ -222,7 +222,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  "framework": "pytorch",
  }

- state_dict = cls._fetch_state_dict(
+ state_dict = _fetch_state_dict(
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
  weight_name=weight_name,
  use_safetensors=use_safetensors,
@@ -282,7 +282,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")
@@ -300,8 +302,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  if not only_text_encoder:
  # Load the layers corresponding to UNet.
  logger.info(f"Loading {cls.unet_name}.")
- unet.load_attn_procs(
+ unet.load_lora_adapter(
  state_dict,
+ prefix=cls.unet_name,
  network_alphas=network_alphas,
  adapter_name=adapter_name,
  _pipeline=_pipeline,
@@ -341,7 +344,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")
@@ -407,6 +412,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  }

  lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
  if "use_dora" in lora_config_kwargs:
  if lora_config_kwargs["use_dora"]:
  if is_peft_version("<", "0.9.0"):
@@ -416,6 +422,17 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  else:
  if is_peft_version("<", "0.9.0"):
  lora_config_kwargs.pop("use_dora")
+
+ if "lora_bias" in lora_config_kwargs:
+ if lora_config_kwargs["lora_bias"]:
+ if is_peft_version("<=", "0.13.2"):
+ raise ValueError(
+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<=", "0.13.2"):
+ lora_config_kwargs.pop("lora_bias")
+
  lora_config = LoraConfig(**lora_config_kwargs)

  # adapter_name
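The new block gates the `lora_bias` option on the installed peft version: configs that request it fail loudly on peft <= 0.13.2, while configs that do not use it silently drop the key so older peft releases keep working. A hedged illustration of what the gate protects against, assuming peft's `LoraConfig` only accepts `lora_bias` from 0.14.0 onward, as the error message above implies:

```python
# Illustration only: on peft >= 0.14.0 this constructs a config with biased LoRA B
# layers; on older peft the unknown keyword would raise a TypeError instead.
from peft import LoraConfig

config = LoraConfig(r=16, lora_alpha=16, lora_bias=True)
```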
@@ -601,7 +618,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  kwargs (`dict`, *optional*):
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  """
@@ -744,7 +763,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  "framework": "pytorch",
  }

- state_dict = cls._fetch_state_dict(
+ state_dict = _fetch_state_dict(
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
  weight_name=weight_name,
  use_safetensors=use_safetensors,
@@ -805,7 +824,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")
@@ -823,8 +844,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  if not only_text_encoder:
  # Load the layers corresponding to UNet.
  logger.info(f"Loading {cls.unet_name}.")
- unet.load_attn_procs(
+ unet.load_lora_adapter(
  state_dict,
+ prefix=cls.unet_name,
  network_alphas=network_alphas,
  adapter_name=adapter_name,
  _pipeline=_pipeline,
@@ -865,7 +887,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")
@@ -931,6 +955,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  }

  lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
  if "use_dora" in lora_config_kwargs:
  if lora_config_kwargs["use_dora"]:
  if is_peft_version("<", "0.9.0"):
@@ -940,6 +965,17 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  else:
  if is_peft_version("<", "0.9.0"):
  lora_config_kwargs.pop("use_dora")
+
+ if "lora_bias" in lora_config_kwargs:
+ if lora_config_kwargs["lora_bias"]:
+ if is_peft_version("<=", "0.13.2"):
+ raise ValueError(
+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<=", "0.13.2"):
+ lora_config_kwargs.pop("lora_bias")
+
  lora_config = LoraConfig(**lora_config_kwargs)

  # adapter_name
@@ -1182,7 +1218,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  "framework": "pytorch",
  }

- state_dict = cls._fetch_state_dict(
+ state_dict = _fetch_state_dict(
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
  weight_name=weight_name,
  use_safetensors=use_safetensors,
@@ -1226,7 +1262,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  kwargs (`dict`, *optional*):
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  """
@@ -1250,13 +1288,17 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  if not is_correct_format:
  raise ValueError("Invalid LoRA checkpoint.")

- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
+ transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
+ if len(transformer_state_dict) > 0:
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name)
+ if not hasattr(self, "transformer")
+ else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )

  text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
  if len(text_encoder_state_dict) > 0:
@@ -1301,94 +1343,24 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
  raise ValueError(
  "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
  )

- from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
- keys = list(state_dict.keys())
-
- transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
- state_dict = {
- k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
- }
-
- if len(state_dict.keys()) > 0:
- # check with first key if is not in peft format
- first_key = next(iter(state_dict.keys()))
- if "lora_A" not in first_key:
- state_dict = convert_unet_state_dict_to_peft(state_dict)
-
- if adapter_name in getattr(transformer, "peft_config", {}):
- raise ValueError(
- f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
- )
-
- rank = {}
- for key, val in state_dict.items():
- if "lora_B" in key:
- rank[key] = val.shape[1]
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- lora_config_kwargs.pop("use_dora")
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(transformer)
-
- # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
- # otherwise loading LoRA weights will lead to an error
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- peft_kwargs = {}
- if is_peft_version(">=", "0.13.1"):
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
- inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
- incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
-
- warn_msg = ""
- if incompatible_keys is not None:
- # Check only for unexpected keys.
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
- if unexpected_keys:
- lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
- if lora_unexpected_keys:
- warn_msg = (
- f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
- f" {', '.join(lora_unexpected_keys)}. "
- )
-
- # Filter missing keys specific to the current adapter.
- missing_keys = getattr(incompatible_keys, "missing_keys", None)
- if missing_keys:
- lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
- if lora_missing_keys:
- warn_msg += (
- f"Loading adapter weights from state_dict led to missing keys in the model:"
- f" {', '.join(lora_missing_keys)}."
- )
-
- if warn_msg:
- logger.warning(warn_msg)
-
- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )

  @classmethod
  # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
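The SD3 loader now delegates to the model-level `load_lora_adapter` introduced in this release (see diffusers/loaders/peft.py in the file list) instead of re-implementing key filtering, adapter injection, and offload handling inline. A hedged sketch of calling that method directly on a transformer, with placeholder identifiers and a signature assumed from the calls shown in this diff:

```python
# Sketch: attach a LoRA to an SD3 transformer without going through a pipeline.
# The checkpoint path and LoRA repo below are placeholders.
from diffusers import SD3Transformer2DModel

transformer = SD3Transformer2DModel.from_pretrained("some/sd3-checkpoint", subfolder="transformer")
transformer.load_lora_adapter("some/sd3-lora", prefix="transformer", low_cpu_mem_usage=True)
```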
@@ -1424,7 +1396,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")
@@ -1490,6 +1464,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  }

  lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
  if "use_dora" in lora_config_kwargs:
  if lora_config_kwargs["use_dora"]:
  if is_peft_version("<", "0.9.0"):
@@ -1499,6 +1474,17 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  else:
  if is_peft_version("<", "0.9.0"):
  lora_config_kwargs.pop("use_dora")
+
+ if "lora_bias" in lora_config_kwargs:
+ if lora_config_kwargs["lora_bias"]:
+ if is_peft_version("<=", "0.13.2"):
+ raise ValueError(
+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<=", "0.13.2"):
+ lora_config_kwargs.pop("lora_bias")
+
  lora_config = LoraConfig(**lora_config_kwargs)

  # adapter_name
@@ -1666,6 +1652,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  _lora_loadable_modules = ["transformer", "text_encoder"]
  transformer_name = TRANSFORMER_NAME
  text_encoder_name = TEXT_ENCODER_NAME
+ _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]

  @classmethod
  @validate_hf_hub_args
@@ -1742,7 +1729,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  "framework": "pytorch",
  }

- state_dict = cls._fetch_state_dict(
+ state_dict = _fetch_state_dict(
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
  weight_name=weight_name,
  use_safetensors=use_safetensors,
@@ -1775,6 +1762,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  # xlabs doesn't use `alpha`.
  return (state_dict, None) if return_alphas else state_dict

+ is_bfl_control = any("query_norm.scale" in k for k in state_dict)
+ if is_bfl_control:
+ state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
+ return (state_dict, None) if return_alphas else state_dict
+
  # For state dicts like
  # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
  keys = list(state_dict.keys())
@@ -1819,7 +1811,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")
@@ -1839,19 +1833,57 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
  )

- is_correct_format = all("lora" in key for key in state_dict.keys())
- if not is_correct_format:
+ has_lora_keys = any("lora" in key for key in state_dict.keys())
+
+ # Flux Control LoRAs also have norm keys
+ has_norm_keys = any(
+ norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys
+ )
+
+ if not (has_lora_keys or has_norm_keys):
  raise ValueError("Invalid LoRA checkpoint.")

- self.load_lora_into_transformer(
- state_dict,
- network_alphas=network_alphas,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
+ transformer_lora_state_dict = {
+ k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
+ }
+ transformer_norm_state_dict = {
+ k: state_dict.pop(k)
+ for k in list(state_dict.keys())
+ if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
+ }
+
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+ has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
+ transformer, transformer_lora_state_dict, transformer_norm_state_dict
+ )
+
+ if has_param_with_expanded_shape:
+ logger.info(
+ "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
+ "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
+ "To get a comprehensive list of parameter names that were modified, enable debug logging."
+ )
+ transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
+ transformer=transformer, lora_state_dict=transformer_lora_state_dict
  )

+ if len(transformer_lora_state_dict) > 0:
+ self.load_lora_into_transformer(
+ transformer_lora_state_dict,
+ network_alphas=network_alphas,
+ transformer=transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ if len(transformer_norm_state_dict) > 0:
+ transformer._transformer_norm_layers = self._load_norm_into_transformer(
+ transformer_norm_state_dict,
+ transformer=transformer,
+ discard_original_layers=False,
+ )
+
  text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
  if len(text_encoder_state_dict) > 0:
  self.load_lora_into_text_encoder(
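From the user side, the new branching above means a single `load_lora_weights` call can now handle Flux Control LoRAs, whose state dicts mix ordinary `lora_A`/`lora_B` keys with plain normalization-layer weights (`norm_q`, `norm_k`, ...): the LoRA keys go through the usual peft path while the norm keys are written straight into the transformer and backed up for later unloading. A hedged end-to-end sketch, with repository ids shown for illustration only:

```python
# Sketch of loading a Flux Control LoRA; repo ids are illustrative, not verified here.
import torch
from diffusers import FluxControlPipeline

pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Routes "transformer.*lora*" keys through the peft adapter path and loads the
# remaining norm keys directly into the transformer (kept in _transformer_norm_layers).
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
```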
@@ -1881,104 +1913,86 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  The value of the network alpha used for stable learning and preventing underflow. This value has the
  same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
  link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- transformer (`SD3Transformer2DModel`):
+ transformer (`FluxTransformer2DModel`):
  The Transformer model to load the LoRA layers into.
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
  if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
  raise ValueError(
  "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
  )

- from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
+ # Load the layers corresponding to transformer.
  keys = list(state_dict.keys())
+ transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
+ if transformer_present:
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )

- transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
- state_dict = {
- k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
- }
-
- if len(state_dict.keys()) > 0:
- # check with first key if is not in peft format
- first_key = next(iter(state_dict.keys()))
- if "lora_A" not in first_key:
- state_dict = convert_unet_state_dict_to_peft(state_dict)
-
- if adapter_name in getattr(transformer, "peft_config", {}):
- raise ValueError(
- f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
- )
+ @classmethod
+ def _load_norm_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ prefix=None,
+ discard_original_layers=False,
+ ) -> Dict[str, torch.Tensor]:
+ # Remove prefix if present
+ prefix = prefix or cls.transformer_name
+ for key in list(state_dict.keys()):
+ if key.split(".")[0] == prefix:
+ state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+
+ # Find invalid keys
+ transformer_state_dict = transformer.state_dict()
+ transformer_keys = set(transformer_state_dict.keys())
+ state_dict_keys = set(state_dict.keys())
+ extra_keys = list(state_dict_keys - transformer_keys)
+
+ if extra_keys:
+ logger.warning(
+ f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
+ )

- rank = {}
- for key, val in state_dict.items():
- if "lora_B" in key:
- rank[key] = val.shape[1]
+ for key in extra_keys:
+ state_dict.pop(key)

- if network_alphas is not None and len(network_alphas) >= 1:
- prefix = cls.transformer_name
- alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
- network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+ # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected
+ overwritten_layers_state_dict = {}
+ if not discard_original_layers:
+ for key in state_dict.keys():
+ overwritten_layers_state_dict[key] = transformer_state_dict[key].clone()

- lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- lora_config_kwargs.pop("use_dora")
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(transformer)
-
- # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
- # otherwise loading LoRA weights will lead to an error
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- peft_kwargs = {}
- if is_peft_version(">=", "0.13.1"):
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
- inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
- incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
-
- warn_msg = ""
- if incompatible_keys is not None:
- # Check only for unexpected keys.
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
- if unexpected_keys:
- lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
- if lora_unexpected_keys:
- warn_msg = (
- f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
- f" {', '.join(lora_unexpected_keys)}. "
- )
+ logger.info(
+ "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
+ 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly '
+ "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. "
+ "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues."
+ )

- # Filter missing keys specific to the current adapter.
- missing_keys = getattr(incompatible_keys, "missing_keys", None)
- if missing_keys:
- lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
- if lora_missing_keys:
- warn_msg += (
- f"Loading adapter weights from state_dict led to missing keys in the model:"
- f" {', '.join(lora_missing_keys)}."
- )
+ # We can't load with strict=True because the current state_dict does not contain all the transformer keys
+ incompatible_keys = transformer.load_state_dict(state_dict, strict=False)
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)

- if warn_msg:
- logger.warning(warn_msg)
+ # We shouldn't expect to see the supported norm keys here being present in the unexpected keys.
+ if unexpected_keys:
+ if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys):
+ raise ValueError(
+ f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer."
+ )

- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
+ return overwritten_layers_state_dict

  @classmethod
  # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2014,7 +2028,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")
@@ -2080,6 +2096,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  }

  lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
  if "use_dora" in lora_config_kwargs:
  if lora_config_kwargs["use_dora"]:
  if is_peft_version("<", "0.9.0"):
@@ -2089,6 +2106,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  else:
  if is_peft_version("<", "0.9.0"):
  lora_config_kwargs.pop("use_dora")
+
+ if "lora_bias" in lora_config_kwargs:
+ if lora_config_kwargs["lora_bias"]:
+ if is_peft_version("<=", "0.13.2"):
+ raise ValueError(
+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<=", "0.13.2"):
+ lora_config_kwargs.pop("lora_bias")
+
  lora_config = LoraConfig(**lora_config_kwargs)

  # adapter_name
@@ -2173,7 +2201,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  safe_serialization=safe_serialization,
  )

- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
  def fuse_lora(
  self,
  components: List[str] = ["transformer", "text_encoder"],
@@ -2213,6 +2240,19 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  pipeline.fuse_lora(lora_scale=0.7)
  ```
  """
+
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+ if (
+ hasattr(transformer, "_transformer_norm_layers")
+ and isinstance(transformer._transformer_norm_layers, dict)
+ and len(transformer._transformer_norm_layers.keys()) > 0
+ ):
+ logger.info(
+ "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer "
+ "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly "
+ "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
+ )
+
  super().fuse_lora(
  components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
  )
@@ -2231,8 +2271,168 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  Args:
  components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
  """
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+ if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
+ transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+
  super().unfuse_lora(components=components)

+ # We override this here account for `_transformer_norm_layers`.
+ def unload_lora_weights(self):
+ super().unload_lora_weights()
+
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+ if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
+ transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+ transformer._transformer_norm_layers = None
+
+ @classmethod
+ def _maybe_expand_transformer_param_shape_or_error_(
+ cls,
+ transformer: torch.nn.Module,
+ lora_state_dict=None,
+ norm_state_dict=None,
+ prefix=None,
+ ) -> bool:
+ """
+ Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
+ generalizes things a bit so that any parameter that needs expansion receives appropriate treatement.
+ """
+ state_dict = {}
+ if lora_state_dict is not None:
+ state_dict.update(lora_state_dict)
+ if norm_state_dict is not None:
+ state_dict.update(norm_state_dict)
+
+ # Remove prefix if present
+ prefix = prefix or cls.transformer_name
+ for key in list(state_dict.keys()):
+ if key.split(".")[0] == prefix:
+ state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+
+ # Expand transformer parameter shapes if they don't match lora
+ has_param_with_shape_update = False
+ is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+ for name, module in transformer.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ module_weight = module.weight.data
+ module_bias = module.bias.data if module.bias is not None else None
+ bias = module_bias is not None
+
+ lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
+ lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
+ lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
+ if lora_A_weight_name not in state_dict:
+ continue
+
+ in_features = state_dict[lora_A_weight_name].shape[1]
+ out_features = state_dict[lora_B_weight_name].shape[0]
+
+ # This means there's no need for an expansion in the params, so we simply skip.
+ if tuple(module_weight.shape) == (out_features, in_features):
+ continue
+
+ module_out_features, module_in_features = module_weight.shape
+ debug_message = ""
+ if in_features > module_in_features:
+ debug_message += (
+ f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
+ f"checkpoint contains higher number of features than expected. The number of input_features will be "
+ f"expanded from {module_in_features} to {in_features}"
+ )
+ if out_features > module_out_features:
+ debug_message += (
+ ", and the number of output features will be "
+ f"expanded from {module_out_features} to {out_features}."
+ )
+ else:
+ debug_message += "."
+ if debug_message:
+ logger.debug(debug_message)
+
+ if out_features > module_out_features or in_features > module_in_features:
+ has_param_with_shape_update = True
+ parent_module_name, _, current_module_name = name.rpartition(".")
+ parent_module = transformer.get_submodule(parent_module_name)
+
+ with torch.device("meta"):
+ expanded_module = torch.nn.Linear(
+ in_features, out_features, bias=bias, dtype=module_weight.dtype
+ )
+ # Only weights are expanded and biases are not. This is because only the input dimensions
+ # are changed while the output dimensions remain the same. The shape of the weight tensor
+ # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
+ # explains the reason why only weights are expanded.
+ new_weight = torch.zeros_like(
+ expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+ )
+ slices = tuple(slice(0, dim) for dim in module_weight.shape)
+ new_weight[slices] = module_weight
+ tmp_state_dict = {"weight": new_weight}
+ if module_bias is not None:
+ tmp_state_dict["bias"] = module_bias
+ expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
+
+ setattr(parent_module, current_module_name, expanded_module)
+
+ del tmp_state_dict
+
+ if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+ attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+ new_value = int(expanded_module.weight.data.shape[1])
+ old_value = getattr(transformer.config, attribute_name)
+ setattr(transformer.config, attribute_name, new_value)
+ logger.info(
+ f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+ )
+
+ return has_param_with_shape_update
+
+ @classmethod
+ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
+ expanded_module_names = set()
+ transformer_state_dict = transformer.state_dict()
+ prefix = f"{cls.transformer_name}."
+
+ lora_module_names = [
+ key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
+ ]
+ lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
+ lora_module_names = sorted(set(lora_module_names))
+ transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
+ unexpected_modules = set(lora_module_names) - set(transformer_module_names)
+ if unexpected_modules:
+ logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
+
+ is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+ for k in lora_module_names:
+ if k in unexpected_modules:
+ continue
+
+ base_param_name = (
+ f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+ )
+ base_weight_param = transformer_state_dict[base_param_name]
+ lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
+
+ if base_weight_param.shape[1] > lora_A_param.shape[1]:
+ shape = (lora_A_param.shape[0], base_weight_param.shape[1])
+ expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
+ expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
+ lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
+ expanded_module_names.add(k)
+ elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+ raise NotImplementedError(
+ f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
+ )
+
+ if expanded_module_names:
+ logger.info(
+ f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
+ )
+
+ return lora_state_dict
+

  # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
  # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
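The core trick in `_maybe_expand_transformer_param_shape_or_error_` is to grow an `nn.Linear` along its input dimension by zero-padding the weight matrix, so the layer keeps producing the same outputs for the original input slice (which is also why only weights, not biases, need expanding). A standalone sketch of that idea, with illustrative shapes matching the Flux `x_embedder` case mentioned in the docstring:

```python
# Zero-pad a Linear from 64 to 128 input features; outputs for the first 64
# inputs are unchanged because the added weight columns are all zeros.
import torch

old = torch.nn.Linear(64, 3072)
new = torch.nn.Linear(128, 3072, dtype=old.weight.dtype)

with torch.no_grad():
    new.weight.zero_()
    new.weight[:, :64].copy_(old.weight)  # weight shape is (out_features, in_features)
    new.bias.copy_(old.bias)              # bias shape (out_features,) stays the same
```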
@@ -2242,7 +2442,10 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
  text_encoder_name = TEXT_ENCODER_NAME

  @classmethod
- def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
+ # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
+ def load_lora_into_transformer(
+ cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
  """
  This will load the LoRA layers specified in `state_dict` into `transformer`.

@@ -2255,93 +2458,32 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
  The value of the network alpha used for stable learning and preventing underflow. This value has the
  same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
  link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- unet (`UNet2DConditionModel`):
- The UNet model to load the LoRA layers into.
+ transformer (`UVit2DModel`):
+ The Transformer model to load the LoRA layers into.
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )

+ # Load the layers corresponding to transformer.
  keys = list(state_dict.keys())
-
- transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
- state_dict = {
- k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
- }
-
- if network_alphas is not None:
- alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)]
- network_alphas = {
- k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
- }
-
- if len(state_dict.keys()) > 0:
- if adapter_name in getattr(transformer, "peft_config", {}):
- raise ValueError(
- f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
- )
-
- rank = {}
- for key, val in state_dict.items():
- if "lora_B" in key:
- rank[key] = val.shape[1]
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- lora_config_kwargs.pop("use_dora")
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(transformer)
-
- # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
- # otherwise loading LoRA weights will lead to an error
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
- incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
-
- warn_msg = ""
- if incompatible_keys is not None:
- # Check only for unexpected keys.
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
- if unexpected_keys:
- lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
2320
- if lora_unexpected_keys:
2321
- warn_msg = (
2322
- f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
2323
- f" {', '.join(lora_unexpected_keys)}. "
2324
- )
2325
-
2326
- # Filter missing keys specific to the current adapter.
2327
- missing_keys = getattr(incompatible_keys, "missing_keys", None)
2328
- if missing_keys:
2329
- lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
2330
- if lora_missing_keys:
2331
- warn_msg += (
2332
- f"Loading adapter weights from state_dict led to missing keys in the model:"
2333
- f" {', '.join(lora_missing_keys)}."
2334
- )
2335
-
2336
- if warn_msg:
2337
- logger.warning(warn_msg)
2338
-
2339
- # Offload back.
2340
- if is_model_cpu_offload:
2341
- _pipeline.enable_model_cpu_offload()
2342
- elif is_sequential_cpu_offload:
2343
- _pipeline.enable_sequential_cpu_offload()
2344
- # Unsafe code />
2477
+ transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
2478
+ if transformer_present:
2479
+ logger.info(f"Loading {cls.transformer_name}.")
2480
+ transformer.load_lora_adapter(
2481
+ state_dict,
2482
+ network_alphas=network_alphas,
2483
+ adapter_name=adapter_name,
2484
+ _pipeline=_pipeline,
2485
+ low_cpu_mem_usage=low_cpu_mem_usage,
2486
+ )
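The hand-rolled PEFT injection removed above is replaced by a delegation to the model-level `load_lora_adapter`, gated on whether any key in the state dict actually targets the transformer. A tiny illustration of that gate; the key names below are made up:

```py
state_dict = {
    "transformer.transformer_layers.0.attn1.to_q.lora_A.weight": None,
    "text_encoder.text_model.encoder.layers.0.mlp.fc1.lora_A.weight": None,
}
# Mirrors `any(key.startswith(cls.transformer_name) for key in keys)` with transformer_name == "transformer".
transformer_present = any(key.startswith("transformer") for key in state_dict)
print(transformer_present)  # True, so transformer.load_lora_adapter(...) would be called
```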
2345
2487
 
2346
2488
  @classmethod
2347
2489
  # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2377,7 +2519,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2377
2519
  adapter_name (`str`, *optional*):
2378
2520
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
2379
2521
  `default_{i}` where i is the total number of adapters being loaded.
2380
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
2522
+ low_cpu_mem_usage (`bool`, *optional*):
2523
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2524
+ weights.
2381
2525
  """
2382
2526
  if not USE_PEFT_BACKEND:
2383
2527
  raise ValueError("PEFT backend is required for this method.")
@@ -2443,6 +2587,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2443
2587
  }
2444
2588
 
2445
2589
  lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
2590
+
2446
2591
  if "use_dora" in lora_config_kwargs:
2447
2592
  if lora_config_kwargs["use_dora"]:
2448
2593
  if is_peft_version("<", "0.9.0"):
@@ -2452,6 +2597,17 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2452
2597
  else:
2453
2598
  if is_peft_version("<", "0.9.0"):
2454
2599
  lora_config_kwargs.pop("use_dora")
2600
+
2601
+ if "lora_bias" in lora_config_kwargs:
2602
+ if lora_config_kwargs["lora_bias"]:
2603
+ if is_peft_version("<=", "0.13.2"):
2604
+ raise ValueError(
2605
+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
2606
+ )
2607
+ else:
2608
+ if is_peft_version("<=", "0.13.2"):
2609
+ lora_config_kwargs.pop("lora_bias")
2610
+
2455
2611
  lora_config = LoraConfig(**lora_config_kwargs)
2456
2612
 
2457
2613
  # adapter_name
@@ -2538,7 +2694,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2538
2694
 
2539
2695
  class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2540
2696
  r"""
2541
- Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoX`].
2697
+ Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
2542
2698
  """
2543
2699
 
2544
2700
  _lora_loadable_modules = ["transformer"]
@@ -2619,7 +2775,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2619
2775
  "framework": "pytorch",
2620
2776
  }
2621
2777
 
2622
- state_dict = cls._fetch_state_dict(
2778
+ state_dict = _fetch_state_dict(
2623
2779
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
2624
2780
  weight_name=weight_name,
2625
2781
  use_safetensors=use_safetensors,
@@ -2658,7 +2814,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2658
2814
  adapter_name (`str`, *optional*):
2659
2815
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
2660
2816
  `default_{i}` where i is the total number of adapters being loaded.
2661
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
2817
+ low_cpu_mem_usage (`bool`, *optional*):
2818
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2819
+ weights.
2662
2820
  kwargs (`dict`, *optional*):
2663
2821
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
2664
2822
  """
@@ -2691,7 +2849,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2691
2849
  )
2692
2850
 
2693
2851
  @classmethod
2694
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
2852
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
2695
2853
  def load_lora_into_transformer(
2696
2854
  cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
2697
2855
  ):
@@ -2703,99 +2861,29 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2703
2861
  A standard state dict containing the lora layer parameters. The keys can either be indexed directly
2704
2862
  into the unet or prefixed with an additional `unet` which can be used to distinguish between text
2705
2863
  encoder lora layers.
2706
- transformer (`SD3Transformer2DModel`):
2864
+ transformer (`CogVideoXTransformer3DModel`):
2707
2865
  The Transformer model to load the LoRA layers into.
2708
2866
  adapter_name (`str`, *optional*):
2709
2867
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
2710
2868
  `default_{i}` where i is the total number of adapters being loaded.
2711
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
2869
+ low_cpu_mem_usage (`bool`, *optional*):
2870
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2871
+ weights.
2712
2872
  """
2713
2873
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
2714
2874
  raise ValueError(
2715
2875
  "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
2716
2876
  )
2717
2877
 
2718
- from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
2719
-
2720
- keys = list(state_dict.keys())
2721
-
2722
- transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
2723
- state_dict = {
2724
- k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
2725
- }
2726
-
2727
- if len(state_dict.keys()) > 0:
2728
- # check with first key if is not in peft format
2729
- first_key = next(iter(state_dict.keys()))
2730
- if "lora_A" not in first_key:
2731
- state_dict = convert_unet_state_dict_to_peft(state_dict)
2732
-
2733
- if adapter_name in getattr(transformer, "peft_config", {}):
2734
- raise ValueError(
2735
- f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
2736
- )
2737
-
2738
- rank = {}
2739
- for key, val in state_dict.items():
2740
- if "lora_B" in key:
2741
- rank[key] = val.shape[1]
2742
-
2743
- lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
2744
- if "use_dora" in lora_config_kwargs:
2745
- if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
2746
- raise ValueError(
2747
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
2748
- )
2749
- else:
2750
- lora_config_kwargs.pop("use_dora")
2751
- lora_config = LoraConfig(**lora_config_kwargs)
2752
-
2753
- # adapter_name
2754
- if adapter_name is None:
2755
- adapter_name = get_adapter_name(transformer)
2756
-
2757
- # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
2758
- # otherwise loading LoRA weights will lead to an error
2759
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
2760
-
2761
- peft_kwargs = {}
2762
- if is_peft_version(">=", "0.13.1"):
2763
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
2764
-
2765
- inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
2766
- incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
2767
-
2768
- warn_msg = ""
2769
- if incompatible_keys is not None:
2770
- # Check only for unexpected keys.
2771
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
2772
- if unexpected_keys:
2773
- lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
2774
- if lora_unexpected_keys:
2775
- warn_msg = (
2776
- f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
2777
- f" {', '.join(lora_unexpected_keys)}. "
2778
- )
2779
-
2780
- # Filter missing keys specific to the current adapter.
2781
- missing_keys = getattr(incompatible_keys, "missing_keys", None)
2782
- if missing_keys:
2783
- lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
2784
- if lora_missing_keys:
2785
- warn_msg += (
2786
- f"Loading adapter weights from state_dict led to missing keys in the model:"
2787
- f" {', '.join(lora_missing_keys)}."
2788
- )
2789
-
2790
- if warn_msg:
2791
- logger.warning(warn_msg)
2792
-
2793
- # Offload back.
2794
- if is_model_cpu_offload:
2795
- _pipeline.enable_model_cpu_offload()
2796
- elif is_sequential_cpu_offload:
2797
- _pipeline.enable_sequential_cpu_offload()
2798
- # Unsafe code />
2878
+ # Load the layers corresponding to transformer.
2879
+ logger.info(f"Loading {cls.transformer_name}.")
2880
+ transformer.load_lora_adapter(
2881
+ state_dict,
2882
+ network_alphas=None,
2883
+ adapter_name=adapter_name,
2884
+ _pipeline=_pipeline,
2885
+ low_cpu_mem_usage=low_cpu_mem_usage,
2886
+ )
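With the transformer loading reduced to a `load_lora_adapter` call, the user-facing entry point remains the pipeline-level `load_lora_weights`. A minimal sketch, assuming a CogVideoX LoRA saved in the diffusers layout (the LoRA repo id is a placeholder):

```py
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
# "your-org/cogvideox-lora" stands in for any LoRA trained against the CogVideoX transformer.
pipe.load_lora_weights("your-org/cogvideox-lora", adapter_name="example", low_cpu_mem_usage=True)
video = pipe("a koala playing the drums", num_inference_steps=50).frames[0]
```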
2799
2887
 
2800
2888
  @classmethod
2801
2889
  # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
@@ -2911,6 +2999,1238 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2911
2999
  super().unfuse_lora(components=components)
2912
3000
 
2913
3001
 
3002
+ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3003
+ r"""
3004
+ Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
3005
+ """
3006
+
3007
+ _lora_loadable_modules = ["transformer"]
3008
+ transformer_name = TRANSFORMER_NAME
3009
+
3010
+ @classmethod
3011
+ @validate_hf_hub_args
3012
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
3013
+ def lora_state_dict(
3014
+ cls,
3015
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
3016
+ **kwargs,
3017
+ ):
3018
+ r"""
3019
+ Return state dict for lora weights and the network alphas.
3020
+
3021
+ <Tip warning={true}>
3022
+
3023
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
3024
+
3025
+ This function is experimental and might change in the future.
3026
+
3027
+ </Tip>
3028
+
3029
+ Parameters:
3030
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3031
+ Can be either:
3032
+
3033
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
3034
+ the Hub.
3035
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
3036
+ with [`ModelMixin.save_pretrained`].
3037
+ - A [torch state
3038
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
3039
+
3040
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
3041
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
3042
+ is not used.
3043
+ force_download (`bool`, *optional*, defaults to `False`):
3044
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
3045
+ cached versions if they exist.
3046
+
3047
+ proxies (`Dict[str, str]`, *optional*):
3048
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
3049
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
3050
+ local_files_only (`bool`, *optional*, defaults to `False`):
3051
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
3052
+ won't be downloaded from the Hub.
3053
+ token (`str` or *bool*, *optional*):
3054
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
3055
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
3056
+ revision (`str`, *optional*, defaults to `"main"`):
3057
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
3058
+ allowed by Git.
3059
+ subfolder (`str`, *optional*, defaults to `""`):
3060
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
3061
+
3062
+ """
3063
+ # Load the main state dict first which has the LoRA layers for either of
3064
+ # transformer and text encoder or both.
3065
+ cache_dir = kwargs.pop("cache_dir", None)
3066
+ force_download = kwargs.pop("force_download", False)
3067
+ proxies = kwargs.pop("proxies", None)
3068
+ local_files_only = kwargs.pop("local_files_only", None)
3069
+ token = kwargs.pop("token", None)
3070
+ revision = kwargs.pop("revision", None)
3071
+ subfolder = kwargs.pop("subfolder", None)
3072
+ weight_name = kwargs.pop("weight_name", None)
3073
+ use_safetensors = kwargs.pop("use_safetensors", None)
3074
+
3075
+ allow_pickle = False
3076
+ if use_safetensors is None:
3077
+ use_safetensors = True
3078
+ allow_pickle = True
3079
+
3080
+ user_agent = {
3081
+ "file_type": "attn_procs_weights",
3082
+ "framework": "pytorch",
3083
+ }
3084
+
3085
+ state_dict = _fetch_state_dict(
3086
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
3087
+ weight_name=weight_name,
3088
+ use_safetensors=use_safetensors,
3089
+ local_files_only=local_files_only,
3090
+ cache_dir=cache_dir,
3091
+ force_download=force_download,
3092
+ proxies=proxies,
3093
+ token=token,
3094
+ revision=revision,
3095
+ subfolder=subfolder,
3096
+ user_agent=user_agent,
3097
+ allow_pickle=allow_pickle,
3098
+ )
3099
+
3100
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
3101
+ if is_dora_scale_present:
3102
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
3103
+ logger.warning(warn_msg)
3104
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
3105
+
3106
+ return state_dict
3107
+
3108
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
3109
+ def load_lora_weights(
3110
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
3111
+ ):
3112
+ """
3113
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
3114
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
3115
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
3116
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
3117
+ dict is loaded into `self.transformer`.
3118
+
3119
+ Parameters:
3120
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3121
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3122
+ adapter_name (`str`, *optional*):
3123
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3124
+ `default_{i}` where i is the total number of adapters being loaded.
3125
+ low_cpu_mem_usage (`bool`, *optional*):
3126
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3127
+ weights.
3128
+ kwargs (`dict`, *optional*):
3129
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3130
+ """
3131
+ if not USE_PEFT_BACKEND:
3132
+ raise ValueError("PEFT backend is required for this method.")
3133
+
3134
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
3135
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3136
+ raise ValueError(
3137
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3138
+ )
3139
+
3140
+ # if a dict is passed, copy it instead of modifying it inplace
3141
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
3142
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
3143
+
3144
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
3145
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3146
+
3147
+ is_correct_format = all("lora" in key for key in state_dict.keys())
3148
+ if not is_correct_format:
3149
+ raise ValueError("Invalid LoRA checkpoint.")
3150
+
3151
+ self.load_lora_into_transformer(
3152
+ state_dict,
3153
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
3154
+ adapter_name=adapter_name,
3155
+ _pipeline=self,
3156
+ low_cpu_mem_usage=low_cpu_mem_usage,
3157
+ )
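Because `Mochi1LoraLoaderMixin.load_lora_weights` is copied verbatim from the CogVideoX mixin, usage is identical. A minimal sketch, assuming `MochiPipeline` carries this mixin and a compatible LoRA exists (the LoRA repo id is a placeholder):

```py
import torch
from diffusers import MochiPipeline

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
pipe.load_lora_weights("your-org/mochi-lora", adapter_name="example")
frames = pipe("a corgi surfing a small wave", num_inference_steps=64).frames[0]
```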
3158
+
3159
+ @classmethod
3160
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
3161
+ def load_lora_into_transformer(
3162
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
3163
+ ):
3164
+ """
3165
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
3166
+
3167
+ Parameters:
3168
+ state_dict (`dict`):
3169
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
3170
+ into the transformer or prefixed with an additional `transformer` which can be used to distinguish between text
3171
+ encoder lora layers.
3172
+ transformer (`MochiTransformer3DModel`):
3173
+ The Transformer model to load the LoRA layers into.
3174
+ adapter_name (`str`, *optional*):
3175
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3176
+ `default_{i}` where i is the total number of adapters being loaded.
3177
+ low_cpu_mem_usage (`bool`, *optional*):
3178
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3179
+ weights.
3180
+ """
3181
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3182
+ raise ValueError(
3183
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3184
+ )
3185
+
3186
+ # Load the layers corresponding to transformer.
3187
+ logger.info(f"Loading {cls.transformer_name}.")
3188
+ transformer.load_lora_adapter(
3189
+ state_dict,
3190
+ network_alphas=None,
3191
+ adapter_name=adapter_name,
3192
+ _pipeline=_pipeline,
3193
+ low_cpu_mem_usage=low_cpu_mem_usage,
3194
+ )
3195
+
3196
+ @classmethod
3197
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
3198
+ def save_lora_weights(
3199
+ cls,
3200
+ save_directory: Union[str, os.PathLike],
3201
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
3202
+ is_main_process: bool = True,
3203
+ weight_name: str = None,
3204
+ save_function: Callable = None,
3205
+ safe_serialization: bool = True,
3206
+ ):
3207
+ r"""
3208
+ Save the LoRA parameters corresponding to the transformer.
3209
+
3210
+ Arguments:
3211
+ save_directory (`str` or `os.PathLike`):
3212
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
3213
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
3214
+ State dict of the LoRA layers corresponding to the `transformer`.
3215
+ is_main_process (`bool`, *optional*, defaults to `True`):
3216
+ Whether the process calling this is the main process or not. Useful during distributed training when you
3217
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
3218
+ process to avoid race conditions.
3219
+ save_function (`Callable`):
3220
+ The function to use to save the state dictionary. Useful during distributed training when you need to
3221
+ replace `torch.save` with another method. Can be configured with the environment variable
3222
+ `DIFFUSERS_SAVE_MODE`.
3223
+ safe_serialization (`bool`, *optional*, defaults to `True`):
3224
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
3225
+ """
3226
+ state_dict = {}
3227
+
3228
+ if not transformer_lora_layers:
3229
+ raise ValueError("You must pass `transformer_lora_layers`.")
3230
+
3231
+ if transformer_lora_layers:
3232
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
3233
+
3234
+ # Save the model
3235
+ cls.write_lora_layers(
3236
+ state_dict=state_dict,
3237
+ save_directory=save_directory,
3238
+ is_main_process=is_main_process,
3239
+ weight_name=weight_name,
3240
+ save_function=save_function,
3241
+ safe_serialization=safe_serialization,
3242
+ )
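Note that `save_lora_weights` here only packs transformer layers; there is no text-encoder branch. A minimal sketch of serializing trained adapter weights, assuming `transformer` is a Mochi transformer with a PEFT adapter already attached and using helpers commonly seen in the diffusers training scripts:

```py
from peft.utils import get_peft_model_state_dict
from diffusers import MochiPipeline
from diffusers.utils import convert_state_dict_to_diffusers

# `transformer` is assumed to be a MochiTransformer3DModel with a trained LoRA adapter.
transformer_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(transformer))

MochiPipeline.save_lora_weights(
    save_directory="./mochi-lora",
    transformer_lora_layers=transformer_lora_layers,
    weight_name="pytorch_lora_weights.safetensors",
)
```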
3243
+
3244
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
3245
+ def fuse_lora(
3246
+ self,
3247
+ components: List[str] = ["transformer", "text_encoder"],
3248
+ lora_scale: float = 1.0,
3249
+ safe_fusing: bool = False,
3250
+ adapter_names: Optional[List[str]] = None,
3251
+ **kwargs,
3252
+ ):
3253
+ r"""
3254
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
3255
+
3256
+ <Tip warning={true}>
3257
+
3258
+ This is an experimental API.
3259
+
3260
+ </Tip>
3261
+
3262
+ Args:
3263
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
3264
+ lora_scale (`float`, defaults to 1.0):
3265
+ Controls how much to influence the outputs with the LoRA parameters.
3266
+ safe_fusing (`bool`, defaults to `False`):
3267
+ Whether to check the fused weights for NaN values before fusing and, if NaN values are present, skip fusing them.
3268
+ adapter_names (`List[str]`, *optional*):
3269
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
3270
+
3271
+ Example:
3272
+
3273
+ ```py
3274
+ from diffusers import DiffusionPipeline
3275
+ import torch
3276
+
3277
+ pipeline = DiffusionPipeline.from_pretrained(
3278
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
3279
+ ).to("cuda")
3280
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
3281
+ pipeline.fuse_lora(lora_scale=0.7)
3282
+ ```
3283
+ """
3284
+ super().fuse_lora(
3285
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
3286
+ )
3287
+
3288
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
3289
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
3290
+ r"""
3291
+ Reverses the effect of
3292
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
3293
+
3294
+ <Tip warning={true}>
3295
+
3296
+ This is an experimental API.
3297
+
3298
+ </Tip>
3299
+
3300
+ Args:
3301
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3302
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
3303
+ unfuse_text_encoder (`bool`, defaults to `True`):
3304
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
3305
+ LoRA parameters then it won't have any effect.
3306
+ """
3307
+ super().unfuse_lora(components=components)
3308
+
3309
+
3310
+ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3311
+ r"""
3312
+ Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`].
3313
+ """
3314
+
3315
+ _lora_loadable_modules = ["transformer"]
3316
+ transformer_name = TRANSFORMER_NAME
3317
+
3318
+ @classmethod
3319
+ @validate_hf_hub_args
3320
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
3321
+ def lora_state_dict(
3322
+ cls,
3323
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
3324
+ **kwargs,
3325
+ ):
3326
+ r"""
3327
+ Return state dict for lora weights and the network alphas.
3328
+
3329
+ <Tip warning={true}>
3330
+
3331
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
3332
+
3333
+ This function is experimental and might change in the future.
3334
+
3335
+ </Tip>
3336
+
3337
+ Parameters:
3338
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3339
+ Can be either:
3340
+
3341
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
3342
+ the Hub.
3343
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
3344
+ with [`ModelMixin.save_pretrained`].
3345
+ - A [torch state
3346
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
3347
+
3348
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
3349
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
3350
+ is not used.
3351
+ force_download (`bool`, *optional*, defaults to `False`):
3352
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
3353
+ cached versions if they exist.
3354
+
3355
+ proxies (`Dict[str, str]`, *optional*):
3356
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
3357
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
3358
+ local_files_only (`bool`, *optional*, defaults to `False`):
3359
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
3360
+ won't be downloaded from the Hub.
3361
+ token (`str` or *bool*, *optional*):
3362
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
3363
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
3364
+ revision (`str`, *optional*, defaults to `"main"`):
3365
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
3366
+ allowed by Git.
3367
+ subfolder (`str`, *optional*, defaults to `""`):
3368
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
3369
+
3370
+ """
3371
+ # Load the main state dict first which has the LoRA layers for either of
3372
+ # transformer and text encoder or both.
3373
+ cache_dir = kwargs.pop("cache_dir", None)
3374
+ force_download = kwargs.pop("force_download", False)
3375
+ proxies = kwargs.pop("proxies", None)
3376
+ local_files_only = kwargs.pop("local_files_only", None)
3377
+ token = kwargs.pop("token", None)
3378
+ revision = kwargs.pop("revision", None)
3379
+ subfolder = kwargs.pop("subfolder", None)
3380
+ weight_name = kwargs.pop("weight_name", None)
3381
+ use_safetensors = kwargs.pop("use_safetensors", None)
3382
+
3383
+ allow_pickle = False
3384
+ if use_safetensors is None:
3385
+ use_safetensors = True
3386
+ allow_pickle = True
3387
+
3388
+ user_agent = {
3389
+ "file_type": "attn_procs_weights",
3390
+ "framework": "pytorch",
3391
+ }
3392
+
3393
+ state_dict = _fetch_state_dict(
3394
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
3395
+ weight_name=weight_name,
3396
+ use_safetensors=use_safetensors,
3397
+ local_files_only=local_files_only,
3398
+ cache_dir=cache_dir,
3399
+ force_download=force_download,
3400
+ proxies=proxies,
3401
+ token=token,
3402
+ revision=revision,
3403
+ subfolder=subfolder,
3404
+ user_agent=user_agent,
3405
+ allow_pickle=allow_pickle,
3406
+ )
3407
+
3408
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
3409
+ if is_dora_scale_present:
3410
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
3411
+ logger.warning(warn_msg)
3412
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
3413
+
3414
+ return state_dict
3415
+
3416
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
3417
+ def load_lora_weights(
3418
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
3419
+ ):
3420
+ """
3421
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
3422
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
3423
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
3424
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
3425
+ dict is loaded into `self.transformer`.
3426
+
3427
+ Parameters:
3428
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3429
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3430
+ adapter_name (`str`, *optional*):
3431
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3432
+ `default_{i}` where i is the total number of adapters being loaded.
3433
+ low_cpu_mem_usage (`bool`, *optional*):
3434
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3435
+ weights.
3436
+ kwargs (`dict`, *optional*):
3437
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3438
+ """
3439
+ if not USE_PEFT_BACKEND:
3440
+ raise ValueError("PEFT backend is required for this method.")
3441
+
3442
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
3443
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3444
+ raise ValueError(
3445
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3446
+ )
3447
+
3448
+ # if a dict is passed, copy it instead of modifying it inplace
3449
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
3450
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
3451
+
3452
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
3453
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3454
+
3455
+ is_correct_format = all("lora" in key for key in state_dict.keys())
3456
+ if not is_correct_format:
3457
+ raise ValueError("Invalid LoRA checkpoint.")
3458
+
3459
+ self.load_lora_into_transformer(
3460
+ state_dict,
3461
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
3462
+ adapter_name=adapter_name,
3463
+ _pipeline=self,
3464
+ low_cpu_mem_usage=low_cpu_mem_usage,
3465
+ )
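The LTX-Video mixin follows the same pattern. A minimal sketch, assuming `LTXPipeline` carries this mixin and a compatible LoRA exists (the LoRA repo id is a placeholder):

```py
import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("your-org/ltx-video-lora", adapter_name="example")
video = pipe("a hot air balloon drifting over snowy mountains", num_inference_steps=50).frames[0]
```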
3466
+
3467
+ @classmethod
3468
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
3469
+ def load_lora_into_transformer(
3470
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
3471
+ ):
3472
+ """
3473
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
3474
+
3475
+ Parameters:
3476
+ state_dict (`dict`):
3477
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
3478
+ into the transformer or prefixed with an additional `transformer` which can be used to distinguish between text
3479
+ encoder lora layers.
3480
+ transformer (`LTXVideoTransformer3DModel`):
3481
+ The Transformer model to load the LoRA layers into.
3482
+ adapter_name (`str`, *optional*):
3483
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3484
+ `default_{i}` where i is the total number of adapters being loaded.
3485
+ low_cpu_mem_usage (`bool`, *optional*):
3486
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3487
+ weights.
3488
+ """
3489
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3490
+ raise ValueError(
3491
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3492
+ )
3493
+
3494
+ # Load the layers corresponding to transformer.
3495
+ logger.info(f"Loading {cls.transformer_name}.")
3496
+ transformer.load_lora_adapter(
3497
+ state_dict,
3498
+ network_alphas=None,
3499
+ adapter_name=adapter_name,
3500
+ _pipeline=_pipeline,
3501
+ low_cpu_mem_usage=low_cpu_mem_usage,
3502
+ )
3503
+
3504
+ @classmethod
3505
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
3506
+ def save_lora_weights(
3507
+ cls,
3508
+ save_directory: Union[str, os.PathLike],
3509
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
3510
+ is_main_process: bool = True,
3511
+ weight_name: str = None,
3512
+ save_function: Callable = None,
3513
+ safe_serialization: bool = True,
3514
+ ):
3515
+ r"""
3516
+ Save the LoRA parameters corresponding to the transformer.
3517
+
3518
+ Arguments:
3519
+ save_directory (`str` or `os.PathLike`):
3520
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
3521
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
3522
+ State dict of the LoRA layers corresponding to the `transformer`.
3523
+ is_main_process (`bool`, *optional*, defaults to `True`):
3524
+ Whether the process calling this is the main process or not. Useful during distributed training when you
3525
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
3526
+ process to avoid race conditions.
3527
+ save_function (`Callable`):
3528
+ The function to use to save the state dictionary. Useful during distributed training when you need to
3529
+ replace `torch.save` with another method. Can be configured with the environment variable
3530
+ `DIFFUSERS_SAVE_MODE`.
3531
+ safe_serialization (`bool`, *optional*, defaults to `True`):
3532
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
3533
+ """
3534
+ state_dict = {}
3535
+
3536
+ if not transformer_lora_layers:
3537
+ raise ValueError("You must pass `transformer_lora_layers`.")
3538
+
3539
+ if transformer_lora_layers:
3540
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
3541
+
3542
+ # Save the model
3543
+ cls.write_lora_layers(
3544
+ state_dict=state_dict,
3545
+ save_directory=save_directory,
3546
+ is_main_process=is_main_process,
3547
+ weight_name=weight_name,
3548
+ save_function=save_function,
3549
+ safe_serialization=safe_serialization,
3550
+ )
3551
+
3552
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
3553
+ def fuse_lora(
3554
+ self,
3555
+ components: List[str] = ["transformer", "text_encoder"],
3556
+ lora_scale: float = 1.0,
3557
+ safe_fusing: bool = False,
3558
+ adapter_names: Optional[List[str]] = None,
3559
+ **kwargs,
3560
+ ):
3561
+ r"""
3562
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
3563
+
3564
+ <Tip warning={true}>
3565
+
3566
+ This is an experimental API.
3567
+
3568
+ </Tip>
3569
+
3570
+ Args:
3571
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
3572
+ lora_scale (`float`, defaults to 1.0):
3573
+ Controls how much to influence the outputs with the LoRA parameters.
3574
+ safe_fusing (`bool`, defaults to `False`):
3575
+ Whether to check the fused weights for NaN values before fusing and, if NaN values are present, skip fusing them.
3576
+ adapter_names (`List[str]`, *optional*):
3577
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
3578
+
3579
+ Example:
3580
+
3581
+ ```py
3582
+ from diffusers import DiffusionPipeline
3583
+ import torch
3584
+
3585
+ pipeline = DiffusionPipeline.from_pretrained(
3586
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
3587
+ ).to("cuda")
3588
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
3589
+ pipeline.fuse_lora(lora_scale=0.7)
3590
+ ```
3591
+ """
3592
+ super().fuse_lora(
3593
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
3594
+ )
3595
+
3596
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
3597
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
3598
+ r"""
3599
+ Reverses the effect of
3600
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
3601
+
3602
+ <Tip warning={true}>
3603
+
3604
+ This is an experimental API.
3605
+
3606
+ </Tip>
3607
+
3608
+ Args:
3609
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3610
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
3611
+ unfuse_text_encoder (`bool`, defaults to `True`):
3612
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
3613
+ LoRA parameters then it won't have any effect.
3614
+ """
3615
+ super().unfuse_lora(components=components)
3616
+
3617
+
3618
+ class SanaLoraLoaderMixin(LoraBaseMixin):
3619
+ r"""
3620
+ Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
3621
+ """
3622
+
3623
+ _lora_loadable_modules = ["transformer"]
3624
+ transformer_name = TRANSFORMER_NAME
3625
+
3626
+ @classmethod
3627
+ @validate_hf_hub_args
3628
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
3629
+ def lora_state_dict(
3630
+ cls,
3631
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
3632
+ **kwargs,
3633
+ ):
3634
+ r"""
3635
+ Return state dict for lora weights and the network alphas.
3636
+
3637
+ <Tip warning={true}>
3638
+
3639
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
3640
+
3641
+ This function is experimental and might change in the future.
3642
+
3643
+ </Tip>
3644
+
3645
+ Parameters:
3646
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3647
+ Can be either:
3648
+
3649
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
3650
+ the Hub.
3651
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
3652
+ with [`ModelMixin.save_pretrained`].
3653
+ - A [torch state
3654
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
3655
+
3656
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
3657
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
3658
+ is not used.
3659
+ force_download (`bool`, *optional*, defaults to `False`):
3660
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
3661
+ cached versions if they exist.
3662
+
3663
+ proxies (`Dict[str, str]`, *optional*):
3664
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
3665
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
3666
+ local_files_only (`bool`, *optional*, defaults to `False`):
3667
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
3668
+ won't be downloaded from the Hub.
3669
+ token (`str` or *bool*, *optional*):
3670
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
3671
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
3672
+ revision (`str`, *optional*, defaults to `"main"`):
3673
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
3674
+ allowed by Git.
3675
+ subfolder (`str`, *optional*, defaults to `""`):
3676
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
3677
+
3678
+ """
3679
+ # Load the main state dict first which has the LoRA layers for either of
3680
+ # transformer and text encoder or both.
3681
+ cache_dir = kwargs.pop("cache_dir", None)
3682
+ force_download = kwargs.pop("force_download", False)
3683
+ proxies = kwargs.pop("proxies", None)
3684
+ local_files_only = kwargs.pop("local_files_only", None)
3685
+ token = kwargs.pop("token", None)
3686
+ revision = kwargs.pop("revision", None)
3687
+ subfolder = kwargs.pop("subfolder", None)
3688
+ weight_name = kwargs.pop("weight_name", None)
3689
+ use_safetensors = kwargs.pop("use_safetensors", None)
3690
+
3691
+ allow_pickle = False
3692
+ if use_safetensors is None:
3693
+ use_safetensors = True
3694
+ allow_pickle = True
3695
+
3696
+ user_agent = {
3697
+ "file_type": "attn_procs_weights",
3698
+ "framework": "pytorch",
3699
+ }
3700
+
3701
+ state_dict = _fetch_state_dict(
3702
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
3703
+ weight_name=weight_name,
3704
+ use_safetensors=use_safetensors,
3705
+ local_files_only=local_files_only,
3706
+ cache_dir=cache_dir,
3707
+ force_download=force_download,
3708
+ proxies=proxies,
3709
+ token=token,
3710
+ revision=revision,
3711
+ subfolder=subfolder,
3712
+ user_agent=user_agent,
3713
+ allow_pickle=allow_pickle,
3714
+ )
3715
+
3716
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
3717
+ if is_dora_scale_present:
3718
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
3719
+ logger.warning(warn_msg)
3720
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
3721
+
3722
+ return state_dict
3723
+
3724
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
3725
+ def load_lora_weights(
3726
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
3727
+ ):
3728
+ """
3729
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
3730
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
3731
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
3732
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
3733
+ dict is loaded into `self.transformer`.
3734
+
3735
+ Parameters:
3736
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3737
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3738
+ adapter_name (`str`, *optional*):
3739
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3740
+ `default_{i}` where i is the total number of adapters being loaded.
3741
+ low_cpu_mem_usage (`bool`, *optional*):
3742
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3743
+ weights.
3744
+ kwargs (`dict`, *optional*):
3745
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3746
+ """
3747
+ if not USE_PEFT_BACKEND:
3748
+ raise ValueError("PEFT backend is required for this method.")
3749
+
3750
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
3751
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3752
+ raise ValueError(
3753
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3754
+ )
3755
+
3756
+ # if a dict is passed, copy it instead of modifying it inplace
3757
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
3758
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
3759
+
3760
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
3761
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3762
+
3763
+ is_correct_format = all("lora" in key for key in state_dict.keys())
3764
+ if not is_correct_format:
3765
+ raise ValueError("Invalid LoRA checkpoint.")
3766
+
3767
+ self.load_lora_into_transformer(
3768
+ state_dict,
3769
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
3770
+ adapter_name=adapter_name,
3771
+ _pipeline=self,
3772
+ low_cpu_mem_usage=low_cpu_mem_usage,
3773
+ )
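And the same flow for Sana; a minimal sketch, assuming `SanaPipeline` carries this mixin (both repo ids are illustrative):

```py
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("your-org/sana-lora", adapter_name="example")
image = pipe("a watercolor painting of a lighthouse").images[0]
```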
3774
+
3775
+ @classmethod
3776
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
3777
+ def load_lora_into_transformer(
3778
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
3779
+ ):
3780
+ """
3781
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
3782
+
3783
+ Parameters:
3784
+ state_dict (`dict`):
3785
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
3786
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
3787
+ encoder lora layers.
3788
+ transformer (`SanaTransformer2DModel`):
3789
+ The Transformer model to load the LoRA layers into.
3790
+ adapter_name (`str`, *optional*):
3791
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3792
+ `default_{i}` where i is the total number of adapters being loaded.
3793
+ low_cpu_mem_usage (`bool`, *optional*):
3794
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3795
+ weights.
3796
+ """
3797
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3798
+ raise ValueError(
3799
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3800
+ )
3801
+
3802
+ # Load the layers corresponding to transformer.
3803
+ logger.info(f"Loading {cls.transformer_name}.")
3804
+ transformer.load_lora_adapter(
3805
+ state_dict,
3806
+ network_alphas=None,
3807
+ adapter_name=adapter_name,
3808
+ _pipeline=_pipeline,
3809
+ low_cpu_mem_usage=low_cpu_mem_usage,
3810
+ )
3811
+
3812
+ @classmethod
3813
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
3814
+ def save_lora_weights(
3815
+ cls,
3816
+ save_directory: Union[str, os.PathLike],
3817
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
3818
+ is_main_process: bool = True,
3819
+ weight_name: str = None,
3820
+ save_function: Callable = None,
3821
+ safe_serialization: bool = True,
3822
+ ):
3823
+ r"""
3824
+ Save the LoRA parameters corresponding to the transformer.
3825
+
3826
+ Arguments:
3827
+ save_directory (`str` or `os.PathLike`):
3828
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
3829
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
3830
+ State dict of the LoRA layers corresponding to the `transformer`.
3831
+ is_main_process (`bool`, *optional*, defaults to `True`):
3832
+ Whether the process calling this is the main process or not. Useful during distributed training when you
3833
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
3834
+ process to avoid race conditions.
3835
+ save_function (`Callable`):
3836
+ The function to use to save the state dictionary. Useful during distributed training when you need to
3837
+ replace `torch.save` with another method. Can be configured with the environment variable
3838
+ `DIFFUSERS_SAVE_MODE`.
3839
+ safe_serialization (`bool`, *optional*, defaults to `True`):
3840
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
3841
+ """
3842
+ state_dict = {}
3843
+
3844
+ if not transformer_lora_layers:
3845
+ raise ValueError("You must pass `transformer_lora_layers`.")
3846
+
3847
+ if transformer_lora_layers:
3848
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
3849
+
3850
+ # Save the model
3851
+ cls.write_lora_layers(
3852
+ state_dict=state_dict,
3853
+ save_directory=save_directory,
3854
+ is_main_process=is_main_process,
3855
+ weight_name=weight_name,
3856
+ save_function=save_function,
3857
+ safe_serialization=safe_serialization,
3858
+ )
3859
+
3860
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
3861
+ def fuse_lora(
3862
+ self,
3863
+ components: List[str] = ["transformer", "text_encoder"],
3864
+ lora_scale: float = 1.0,
3865
+ safe_fusing: bool = False,
3866
+ adapter_names: Optional[List[str]] = None,
3867
+ **kwargs,
3868
+ ):
3869
+ r"""
3870
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
3871
+
3872
+ <Tip warning={true}>
3873
+
3874
+ This is an experimental API.
3875
+
3876
+ </Tip>
3877
+
3878
+ Args:
3879
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
3880
+ lora_scale (`float`, defaults to 1.0):
3881
+ Controls how much to influence the outputs with the LoRA parameters.
3882
+ safe_fusing (`bool`, defaults to `False`):
3883
+ Whether to check the fused weights for NaN values before fusing and, if NaN values are present, skip fusing them.
3884
+ adapter_names (`List[str]`, *optional*):
3885
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
3886
+
3887
+ Example:
3888
+
3889
+ ```py
3890
+ from diffusers import DiffusionPipeline
3891
+ import torch
3892
+
3893
+ pipeline = DiffusionPipeline.from_pretrained(
3894
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
3895
+ ).to("cuda")
3896
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
3897
+ pipeline.fuse_lora(lora_scale=0.7)
3898
+ ```
3899
+ """
3900
+ super().fuse_lora(
3901
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
3902
+ )
3903
+
3904
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
3905
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
3906
+ r"""
3907
+ Reverses the effect of
3908
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
3909
+
3910
+ <Tip warning={true}>
3911
+
3912
+ This is an experimental API.
3913
+
3914
+ </Tip>
3915
+
3916
+ Args:
3917
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3918
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
3919
+ unfuse_text_encoder (`bool`, defaults to `True`):
3920
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
3921
+ LoRA parameters then it won't have any effect.
3922
+ """
3923
+ super().unfuse_lora(components=components)
3924
+
3925
+
3926
+ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
3927
+ r"""
3928
+ Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
3929
+ """
3930
+
3931
+ _lora_loadable_modules = ["transformer"]
3932
+ transformer_name = TRANSFORMER_NAME
3933
+
3934
+ @classmethod
3935
+ @validate_hf_hub_args
3936
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
3937
+ def lora_state_dict(
3938
+ cls,
3939
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
3940
+ **kwargs,
3941
+ ):
3942
+ r"""
3943
+ Return the state dict for the LoRA weights.
3944
+
3945
+ <Tip warning={true}>
3946
+
3947
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
3948
+
3949
+ This function is experimental and might change in the future.
3950
+
3951
+ </Tip>
3952
+
3953
+ Parameters:
3954
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3955
+ Can be either:
3956
+
3957
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
3958
+ the Hub.
3959
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
3960
+ with [`ModelMixin.save_pretrained`].
3961
+ - A [torch state
3962
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
3963
+
3964
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
3965
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
3966
+ is not used.
3967
+ force_download (`bool`, *optional*, defaults to `False`):
3968
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
3969
+ cached versions if they exist.
3970
+
3971
+ proxies (`Dict[str, str]`, *optional*):
3972
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
3973
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
3974
+ local_files_only (`bool`, *optional*, defaults to `False`):
3975
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
3976
+ won't be downloaded from the Hub.
3977
+ token (`str` or *bool*, *optional*):
3978
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
3979
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
3980
+ revision (`str`, *optional*, defaults to `"main"`):
3981
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
3982
+ allowed by Git.
3983
+ subfolder (`str`, *optional*, defaults to `""`):
3984
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
3985
+
3986
+ """
3987
+ # Load the main state dict first which has the LoRA layers for either of
3988
+ # transformer and text encoder or both.
3989
+ cache_dir = kwargs.pop("cache_dir", None)
3990
+ force_download = kwargs.pop("force_download", False)
3991
+ proxies = kwargs.pop("proxies", None)
3992
+ local_files_only = kwargs.pop("local_files_only", None)
3993
+ token = kwargs.pop("token", None)
3994
+ revision = kwargs.pop("revision", None)
3995
+ subfolder = kwargs.pop("subfolder", None)
3996
+ weight_name = kwargs.pop("weight_name", None)
3997
+ use_safetensors = kwargs.pop("use_safetensors", None)
3998
+
3999
+ allow_pickle = False
4000
+ if use_safetensors is None:
4001
+ use_safetensors = True
4002
+ allow_pickle = True
4003
+
4004
+ user_agent = {
4005
+ "file_type": "attn_procs_weights",
4006
+ "framework": "pytorch",
4007
+ }
4008
+
4009
+ state_dict = _fetch_state_dict(
4010
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
4011
+ weight_name=weight_name,
4012
+ use_safetensors=use_safetensors,
4013
+ local_files_only=local_files_only,
4014
+ cache_dir=cache_dir,
4015
+ force_download=force_download,
4016
+ proxies=proxies,
4017
+ token=token,
4018
+ revision=revision,
4019
+ subfolder=subfolder,
4020
+ user_agent=user_agent,
4021
+ allow_pickle=allow_pickle,
4022
+ )
4023
+
4024
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
4025
+ if is_dora_scale_present:
4026
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
4027
+ logger.warning(warn_msg)
4028
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
4029
+
4030
+ return state_dict
4031
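A hedged sketch of using `lora_state_dict` to inspect a checkpoint before loading it (the repository id and weight name below are placeholders, not real artifacts):

```py
from diffusers import HunyuanVideoPipeline

# Placeholder repo id / file name, purely for illustration.
state_dict = HunyuanVideoPipeline.lora_state_dict(
    "some-user/some-hunyuanvideo-lora", weight_name="pytorch_lora_weights.safetensors"
)

# Every key should contain "lora"; keys prefixed with "transformer." target the
# video transformer, and any "dora_scale" entries have already been filtered out.
print(len(state_dict), next(iter(state_dict)))
```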
+
4032
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
4033
+ def load_lora_weights(
4034
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
4035
+ ):
4036
+ """
4037
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.
4038
+ All kwargs are forwarded to `self.lora_state_dict`. See
4039
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
4040
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
4041
+ dict is loaded into `self.transformer`.
4042
+
4043
+ Parameters:
4044
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
4045
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4046
+ adapter_name (`str`, *optional*):
4047
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4048
+ `default_{i}` where i is the total number of adapters being loaded.
4049
+ low_cpu_mem_usage (`bool`, *optional*):
4050
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4051
+ weights.
4052
+ kwargs (`dict`, *optional*):
4053
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4054
+ """
4055
+ if not USE_PEFT_BACKEND:
4056
+ raise ValueError("PEFT backend is required for this method.")
4057
+
4058
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
4059
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4060
+ raise ValueError(
4061
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4062
+ )
4063
+
4064
+ # if a dict is passed, copy it instead of modifying it inplace
4065
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
4066
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
4067
+
4068
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4069
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4070
+
4071
+ is_correct_format = all("lora" in key for key in state_dict.keys())
4072
+ if not is_correct_format:
4073
+ raise ValueError("Invalid LoRA checkpoint.")
4074
+
4075
+ self.load_lora_into_transformer(
4076
+ state_dict,
4077
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
4078
+ adapter_name=adapter_name,
4079
+ _pipeline=self,
4080
+ low_cpu_mem_usage=low_cpu_mem_usage,
4081
+ )
4082
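A hedged end-to-end sketch of loading LoRA weights with the method above (the base checkpoint id is assumed to be a diffusers-format HunyuanVideo repository and the LoRA repo id is a placeholder):

```py
import torch
from diffusers import HunyuanVideoPipeline

# Assumed diffusers-format base weights; replace both ids with your own checkpoints.
pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights(
    "some-user/some-hunyuanvideo-lora",  # placeholder LoRA repo id
    adapter_name="my_adapter",
    low_cpu_mem_usage=True,              # requires peft >= 0.13.0
)
```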
+
4083
+ @classmethod
4084
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
4085
+ def load_lora_into_transformer(
4086
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
4087
+ ):
4088
+ """
4089
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
4090
+
4091
+ Parameters:
4092
+ state_dict (`dict`):
4093
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
4094
+ into the transformer or prefixed with an additional `transformer` which can be used to distinguish between text
4095
+ encoder lora layers.
4096
+ transformer (`HunyuanVideoTransformer3DModel`):
4097
+ The Transformer model to load the LoRA layers into.
4098
+ adapter_name (`str`, *optional*):
4099
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4100
+ `default_{i}` where i is the total number of adapters being loaded.
4101
+ low_cpu_mem_usage (`bool`, *optional*):
4102
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4103
+ weights.
4104
+ """
4105
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4106
+ raise ValueError(
4107
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4108
+ )
4109
+
4110
+ # Load the layers corresponding to transformer.
4111
+ logger.info(f"Loading {cls.transformer_name}.")
4112
+ transformer.load_lora_adapter(
4113
+ state_dict,
4114
+ network_alphas=None,
4115
+ adapter_name=adapter_name,
4116
+ _pipeline=_pipeline,
4117
+ low_cpu_mem_usage=low_cpu_mem_usage,
4118
+ )
4119
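For the lower-level path, a hedged sketch of pushing a state dict directly into the transformer with `load_lora_into_transformer` (this is essentially what `load_lora_weights` does internally; the repository ids are placeholders):

```py
from diffusers import HunyuanVideoPipeline

pipe = HunyuanVideoPipeline.from_pretrained("hunyuanvideo-community/HunyuanVideo")

# Fetch the raw LoRA state dict, then hand it to the transformer-specific loader.
state_dict = HunyuanVideoPipeline.lora_state_dict("some-user/some-hunyuanvideo-lora")
HunyuanVideoPipeline.load_lora_into_transformer(
    state_dict, transformer=pipe.transformer, adapter_name="my_adapter"
)
```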
+
4120
+ @classmethod
4121
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
4122
+ def save_lora_weights(
4123
+ cls,
4124
+ save_directory: Union[str, os.PathLike],
4125
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
4126
+ is_main_process: bool = True,
4127
+ weight_name: str = None,
4128
+ save_function: Callable = None,
4129
+ safe_serialization: bool = True,
4130
+ ):
4131
+ r"""
4132
+ Save the LoRA parameters corresponding to the transformer.
4133
+
4134
+ Arguments:
4135
+ save_directory (`str` or `os.PathLike`):
4136
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
4137
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
4138
+ State dict of the LoRA layers corresponding to the `transformer`.
4139
+ is_main_process (`bool`, *optional*, defaults to `True`):
4140
+ Whether the process calling this is the main process or not. Useful during distributed training when you
4141
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
4142
+ process to avoid race conditions.
4143
+ save_function (`Callable`):
4144
+ The function to use to save the state dictionary. Useful during distributed training when you need to
4145
+ replace `torch.save` with another method. Can be configured with the environment variable
4146
+ `DIFFUSERS_SAVE_MODE`.
4147
+ safe_serialization (`bool`, *optional*, defaults to `True`):
4148
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
4149
+ """
4150
+ state_dict = {}
4151
+
4152
+ if not transformer_lora_layers:
4153
+ raise ValueError("You must pass `transformer_lora_layers`.")
4154
+
4155
+ if transformer_lora_layers:
4156
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
4157
+
4158
+ # Save the model
4159
+ cls.write_lora_layers(
4160
+ state_dict=state_dict,
4161
+ save_directory=save_directory,
4162
+ is_main_process=is_main_process,
4163
+ weight_name=weight_name,
4164
+ save_function=save_function,
4165
+ safe_serialization=safe_serialization,
4166
+ )
4167
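A hedged sketch of saving trained LoRA layers with the classmethod above, assuming a transformer that had LoRA adapters injected via PEFT during training (`get_peft_model_state_dict` is PEFT's adapter-extraction utility; the output directory is a placeholder):

```py
from peft.utils import get_peft_model_state_dict

from diffusers import HunyuanVideoPipeline

# `transformer` is assumed to be a HunyuanVideoTransformer3DModel carrying
# PEFT-injected LoRA layers from a training run.
transformer_lora_layers = get_peft_model_state_dict(transformer)

HunyuanVideoPipeline.save_lora_weights(
    save_directory="./my-hunyuanvideo-lora",          # placeholder path
    transformer_lora_layers=transformer_lora_layers,  # packed under the "transformer." prefix
    safe_serialization=True,                          # writes a .safetensors file
)
```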
+
4168
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
4169
+ def fuse_lora(
4170
+ self,
4171
+ components: List[str] = ["transformer", "text_encoder"],
4172
+ lora_scale: float = 1.0,
4173
+ safe_fusing: bool = False,
4174
+ adapter_names: Optional[List[str]] = None,
4175
+ **kwargs,
4176
+ ):
4177
+ r"""
4178
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
4179
+
4180
+ <Tip warning={true}>
4181
+
4182
+ This is an experimental API.
4183
+
4184
+ </Tip>
4185
+
4186
+ Args:
4187
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
4188
+ lora_scale (`float`, defaults to 1.0):
4189
+ Controls how much to influence the outputs with the LoRA parameters.
4190
+ safe_fusing (`bool`, defaults to `False`):
4191
+ Whether to check fused weights for NaN values before fusing and, if any values are NaN, skip fusing them.
4192
+ adapter_names (`List[str]`, *optional*):
4193
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
4194
+
4195
+ Example:
4196
+
4197
+ ```py
4198
+ from diffusers import DiffusionPipeline
4199
+ import torch
4200
+
4201
+ pipeline = DiffusionPipeline.from_pretrained(
4202
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
4203
+ ).to("cuda")
4204
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
4205
+ pipeline.fuse_lora(lora_scale=0.7)
4206
+ ```
4207
+ """
4208
+ super().fuse_lora(
4209
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
4210
+ )
4211
+
4212
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
4213
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
4214
+ r"""
4215
+ Reverses the effect of
4216
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
4217
+
4218
+ <Tip warning={true}>
4219
+
4220
+ This is an experimental API.
4221
+
4222
+ </Tip>
4223
+
4224
+ Args:
4225
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
4226
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
4227
+ unfuse_text_encoder (`bool`, defaults to `True`):
4228
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
4229
+ LoRA parameters then it won't have any effect.
4230
+ """
4231
+ super().unfuse_lora(components=components)
4232
+
4233
+
2914
4234
  class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2915
4235
  def __init__(self, *args, **kwargs):
2916
4236
  deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."