diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
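Among the largest additions above is the new `diffusers/quantizers` package (entries 214-225), with bitsandbytes, GGUF, and torchao backends plus a shared `quantization_config.py`. A minimal sketch of what that module enables, assuming bitsandbytes is installed; the repo id is illustrative and not taken from this diff:

```python
# Sketch only: model-level quantization via the new diffusers/quantizers package.
# Assumes `bitsandbytes` is installed; the checkpoint id is illustrative.
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=nf4_config,  # routed through diffusers/quantizers/auto.py
    torch_dtype=torch.bfloat16,
)
```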
@@ -11,6 +11,7 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+
  import os
  from typing import Callable, Dict, List, Optional, Union

@@ -21,17 +22,36 @@ from ..utils import (
  USE_PEFT_BACKEND,
  convert_state_dict_to_diffusers,
  convert_state_dict_to_peft,
- convert_unet_state_dict_to_peft,
  deprecate,
  get_adapter_name,
  get_peft_kwargs,
+ is_peft_available,
  is_peft_version,
+ is_torch_version,
  is_transformers_available,
+ is_transformers_version,
  logging,
  scale_lora_layers,
  )
- from .lora_base import LoraBaseMixin
- from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
+ from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict  # noqa
+ from .lora_conversion_utils import (
+ _convert_bfl_flux_control_lora_to_diffusers,
+ _convert_kohya_flux_lora_to_diffusers,
+ _convert_non_diffusers_lora_to_diffusers,
+ _convert_xlabs_flux_lora_to_diffusers,
+ _maybe_map_sgm_blocks_to_diffusers,
+ )
+
+
+ _LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
+ if is_torch_version(">=", "1.9.0"):
+ if (
+ is_peft_available()
+ and is_peft_version(">=", "0.13.1")
+ and is_transformers_available()
+ and is_transformers_version(">", "4.45.2")
+ ):
+ _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True


  if is_transformers_available():
@@ -43,8 +63,7 @@ TEXT_ENCODER_NAME = "text_encoder"
  UNET_NAME = "unet"
  TRANSFORMER_NAME = "transformer"

- LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
- LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+ _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}


  class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
@@ -78,15 +97,24 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  Parameters:
  pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
  # if a dict is passed, copy it instead of modifying it inplace
  if isinstance(pretrained_model_name_or_path_or_dict, dict):
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
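Taken together, the hunks above thread an opt-in `low_cpu_mem_usage` flag through `load_lora_weights`, defaulting to `_LOW_CPU_MEM_USAGE_DEFAULT_LORA` when the installed `torch`/`peft`/`transformers` versions allow it. A minimal usage sketch; the repo ids are placeholders, not from this diff:

```python
# Sketch only: opting into the faster LoRA loading path added in 0.32.0.
# Requires peft >= 0.13.1; text-encoder LoRAs also need transformers > 4.45.2,
# per the version guards in this file.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("some/sd-checkpoint", torch_dtype=torch.float16)
# Adapter weights are loaded directly instead of first materializing randomly
# initialized LoRA layers; the flag is popped from kwargs as shown above.
pipe.load_lora_weights("some-user/some-lora", low_cpu_mem_usage=True)
```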
@@ -94,7 +122,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
  state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

- is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+ is_correct_format = all("lora" in key for key in state_dict.keys())
  if not is_correct_format:
  raise ValueError("Invalid LoRA checkpoint.")

@@ -104,6 +132,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
  adapter_name=adapter_name,
  _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )
  self.load_lora_into_text_encoder(
  state_dict,
@@ -114,6 +143,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  lora_scale=self.lora_scale,
  adapter_name=adapter_name,
  _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  @classmethod
@@ -192,7 +222,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  "framework": "pytorch",
  }

- state_dict = cls._fetch_state_dict(
+ state_dict = _fetch_state_dict(
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
  weight_name=weight_name,
  use_safetensors=use_safetensors,
@@ -206,6 +236,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  user_agent=user_agent,
  allow_pickle=allow_pickle,
  )
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

  network_alphas = None
  # TODO: replace it with a method from `state_dict_utils`
@@ -227,7 +262,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  return state_dict, network_alphas

  @classmethod
- def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
+ def load_lora_into_unet(
+ cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
  """
  This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -245,10 +282,18 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
  # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
  # their prefixes.
@@ -257,8 +302,13 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  if not only_text_encoder:
  # Load the layers corresponding to UNet.
  logger.info(f"Loading {cls.unet_name}.")
- unet.load_attn_procs(
- state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+ unet.load_lora_adapter(
+ state_dict,
+ prefix=cls.unet_name,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  @classmethod
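This hunk retires `unet.load_attn_procs` in favor of the PEFT-backed, model-level `load_lora_adapter` defined in the new `diffusers/loaders/peft.py` (entry 10 in the file list). A sketch of the call shape, under the assumption that `load_lora_adapter` accepts a repo id, local path, or state dict just like the pipeline loaders; the repo ids are placeholders:

```python
# Sketch only (repo ids are placeholders): the model-level loader that the
# pipeline mixin now delegates to, replacing unet.load_attn_procs.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("some/sd-checkpoint")
pipe.unet.load_lora_adapter(
    "some-user/some-lora",   # path, repo id, or an already-loaded state dict
    prefix="unet",           # only keys under "unet." are matched, as above
    adapter_name="style",    # optional; "default_{i}" names are generated otherwise
)
```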
@@ -271,6 +321,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  lora_scale=1.0,
  adapter_name=None,
  _pipeline=None,
+ low_cpu_mem_usage=False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -280,7 +331,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  A standard state dict containing the lora layer parameters. The key should be prefixed with an
  additional `text_encoder` to distinguish between unet lora layers.
  network_alphas (`Dict[str, float]`):
- See `LoRALinearLayer` for more details.
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
  text_encoder (`CLIPTextModel`):
  The text encoder model to load the LoRA layers into.
  prefix (`str`):
@@ -291,10 +344,27 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ peft_kwargs = {}
+ if low_cpu_mem_usage:
+ if not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+ if not is_transformers_version(">", "4.45.2"):
+ # Note from sayakpaul: It's not in `transformers` stable yet.
+ # https://github.com/huggingface/transformers/pull/33725/
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+ )
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
  from peft import LoraConfig

  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -342,6 +412,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  }

  lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
  if "use_dora" in lora_config_kwargs:
  if lora_config_kwargs["use_dora"]:
  if is_peft_version("<", "0.9.0"):
@@ -351,6 +422,17 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  else:
  if is_peft_version("<", "0.9.0"):
  lora_config_kwargs.pop("use_dora")
+
+ if "lora_bias" in lora_config_kwargs:
+ if lora_config_kwargs["lora_bias"]:
+ if is_peft_version("<=", "0.13.2"):
+ raise ValueError(
+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<=", "0.13.2"):
+ lora_config_kwargs.pop("lora_bias")
+
  lora_config = LoraConfig(**lora_config_kwargs)

  # adapter_name
@@ -365,6 +447,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  adapter_name=adapter_name,
  adapter_state_dict=text_encoder_lora_state_dict,
  peft_config=lora_config,
+ **peft_kwargs,
  )

  # scale LoRA layers with `lora_scale`
@@ -535,12 +618,21 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  kwargs (`dict`, *optional*):
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
  # We could have accessed the unet config from `lora_state_dict()` too. We pass
  # it here explicitly to be able to tell that it's coming from an SDXL
  # pipeline.
@@ -555,12 +647,18 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  unet_config=self.unet.config,
  **kwargs,
  )
- is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
  if not is_correct_format:
  raise ValueError("Invalid LoRA checkpoint.")

  self.load_lora_into_unet(
- state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
+ state_dict,
+ network_alphas=network_alphas,
+ unet=self.unet,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )
  text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
  if len(text_encoder_state_dict) > 0:
@@ -572,6 +670,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  lora_scale=self.lora_scale,
  adapter_name=adapter_name,
  _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -584,6 +683,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  lora_scale=self.lora_scale,
  adapter_name=adapter_name,
  _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  @classmethod
@@ -663,7 +763,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  "framework": "pytorch",
  }

- state_dict = cls._fetch_state_dict(
+ state_dict = _fetch_state_dict(
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
  weight_name=weight_name,
  use_safetensors=use_safetensors,
@@ -677,6 +777,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  user_agent=user_agent,
  allow_pickle=allow_pickle,
  )
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

  network_alphas = None
  # TODO: replace it with a method from `state_dict_utils`
@@ -699,7 +804,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):

  @classmethod
  # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
- def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
+ def load_lora_into_unet(
+ cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
  """
  This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -717,10 +824,18 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
  # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
  # their prefixes.
@@ -729,8 +844,13 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  if not only_text_encoder:
  # Load the layers corresponding to UNet.
  logger.info(f"Loading {cls.unet_name}.")
- unet.load_attn_procs(
- state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+ unet.load_lora_adapter(
+ state_dict,
+ prefix=cls.unet_name,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  @classmethod
@@ -744,6 +864,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  lora_scale=1.0,
  adapter_name=None,
  _pipeline=None,
+ low_cpu_mem_usage=False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -753,7 +874,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  A standard state dict containing the lora layer parameters. The key should be prefixed with an
  additional `text_encoder` to distinguish between unet lora layers.
  network_alphas (`Dict[str, float]`):
- See `LoRALinearLayer` for more details.
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
  text_encoder (`CLIPTextModel`):
  The text encoder model to load the LoRA layers into.
  prefix (`str`):
@@ -764,10 +887,27 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ peft_kwargs = {}
+ if low_cpu_mem_usage:
+ if not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+ if not is_transformers_version(">", "4.45.2"):
+ # Note from sayakpaul: It's not in `transformers` stable yet.
+ # https://github.com/huggingface/transformers/pull/33725/
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+ )
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
  from peft import LoraConfig

  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -815,6 +955,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  }

  lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
  if "use_dora" in lora_config_kwargs:
  if lora_config_kwargs["use_dora"]:
  if is_peft_version("<", "0.9.0"):
@@ -824,6 +965,17 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  else:
  if is_peft_version("<", "0.9.0"):
  lora_config_kwargs.pop("use_dora")
+
+ if "lora_bias" in lora_config_kwargs:
+ if lora_config_kwargs["lora_bias"]:
+ if is_peft_version("<=", "0.13.2"):
+ raise ValueError(
+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<=", "0.13.2"):
+ lora_config_kwargs.pop("lora_bias")
+
  lora_config = LoraConfig(**lora_config_kwargs)

  # adapter_name
@@ -838,6 +990,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  adapter_name=adapter_name,
  adapter_state_dict=text_encoder_lora_state_dict,
  peft_config=lora_config,
+ **peft_kwargs,
  )

  # scale LoRA layers with `lora_scale`
@@ -1065,7 +1218,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  "framework": "pytorch",
  }

- state_dict = cls._fetch_state_dict(
+ state_dict = _fetch_state_dict(
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
  weight_name=weight_name,
  use_safetensors=use_safetensors,
@@ -1080,6 +1233,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  allow_pickle=allow_pickle,
  )

+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
  return state_dict

  def load_lora_weights(
@@ -1100,15 +1259,24 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  Parameters:
  pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
  # if a dict is passed, copy it instead of modifying it inplace
  if isinstance(pretrained_model_name_or_path_or_dict, dict):
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -1116,16 +1284,21 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
  state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

- is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+ is_correct_format = all("lora" in key for key in state_dict.keys())
  if not is_correct_format:
  raise ValueError("Invalid LoRA checkpoint.")

- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- )
+ transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
+ if len(transformer_state_dict) > 0:
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name)
+ if not hasattr(self, "transformer")
+ else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )

  text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
  if len(text_encoder_state_dict) > 0:
@@ -1137,6 +1310,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  lora_scale=self.lora_scale,
  adapter_name=adapter_name,
  _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -1149,10 +1323,13 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  lora_scale=self.lora_scale,
  adapter_name=adapter_name,
  _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )

  @classmethod
- def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
  """
  This will load the LoRA layers specified in `state_dict` into `transformer`.

@@ -1166,68 +1343,24 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
- from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
- keys = list(state_dict.keys())
-
- transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
- state_dict = {
- k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
- }
-
- if len(state_dict.keys()) > 0:
- # check with first key if is not in peft format
- first_key = next(iter(state_dict.keys()))
- if "lora_A" not in first_key:
- state_dict = convert_unet_state_dict_to_peft(state_dict)
-
- if adapter_name in getattr(transformer, "peft_config", {}):
- raise ValueError(
- f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
- )
-
- rank = {}
- for key, val in state_dict.items():
- if "lora_B" in key:
- rank[key] = val.shape[1]
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- lora_config_kwargs.pop("use_dora")
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(transformer)
-
- # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
- # otherwise loading LoRA weights will lead to an error
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
- incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
-
- if incompatible_keys is not None:
- # check only for unexpected keys
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
- if unexpected_keys:
- logger.warning(
- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
- f" {unexpected_keys}. "
- )
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )

- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )

  @classmethod
  # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
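The hunk above deletes SD3's hand-rolled PEFT injection (key filtering, rank inference, `inject_adapter_in_model`, offload juggling) and delegates it all to the shared `load_lora_adapter`. From the user's side nothing changes; a sketch with placeholder repo ids:

```python
# Sketch only (repo ids are placeholders): the SD3 mixin now routes through
# SD3Transformer2DModel.load_lora_adapter internally, so a single
# pipeline-level call covers transformer and both text encoders.
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained("some/sd3-checkpoint")
pipe.load_lora_weights("some-user/some-sd3-lora", adapter_name="style")
```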
@@ -1240,6 +1373,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  lora_scale=1.0,
  adapter_name=None,
  _pipeline=None,
+ low_cpu_mem_usage=False,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1249,7 +1383,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  A standard state dict containing the lora layer parameters. The key should be prefixed with an
  additional `text_encoder` to distinguish between unet lora layers.
  network_alphas (`Dict[str, float]`):
- See `LoRALinearLayer` for more details.
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
  text_encoder (`CLIPTextModel`):
  The text encoder model to load the LoRA layers into.
  prefix (`str`):
@@ -1260,10 +1396,27 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ peft_kwargs = {}
+ if low_cpu_mem_usage:
+ if not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+ if not is_transformers_version(">", "4.45.2"):
+ # Note from sayakpaul: It's not in `transformers` stable yet.
+ # https://github.com/huggingface/transformers/pull/33725/
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+ )
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
  from peft import LoraConfig

  # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1311,6 +1464,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  }

  lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
  if "use_dora" in lora_config_kwargs:
  if lora_config_kwargs["use_dora"]:
  if is_peft_version("<", "0.9.0"):
@@ -1320,6 +1474,17 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  else:
  if is_peft_version("<", "0.9.0"):
  lora_config_kwargs.pop("use_dora")
+
+ if "lora_bias" in lora_config_kwargs:
+ if lora_config_kwargs["lora_bias"]:
+ if is_peft_version("<=", "0.13.2"):
+ raise ValueError(
+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<=", "0.13.2"):
+ lora_config_kwargs.pop("lora_bias")
+
  lora_config = LoraConfig(**lora_config_kwargs)

  # adapter_name
@@ -1334,6 +1499,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
  adapter_name=adapter_name,
  adapter_state_dict=text_encoder_lora_state_dict,
  peft_config=lora_config,
+ **peft_kwargs,
  )

  # scale LoRA layers with `lora_scale`
@@ -1486,6 +1652,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  _lora_loadable_modules = ["transformer", "text_encoder"]
  transformer_name = TRANSFORMER_NAME
  text_encoder_name = TEXT_ENCODER_NAME
+ _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]

  @classmethod
  @validate_hf_hub_args
@@ -1562,7 +1729,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  "framework": "pytorch",
  }

- state_dict = cls._fetch_state_dict(
+ state_dict = _fetch_state_dict(
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
  weight_name=weight_name,
  use_safetensors=use_safetensors,
@@ -1576,6 +1743,29 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  user_agent=user_agent,
  allow_pickle=allow_pickle,
  )
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
+ is_kohya = any(".lora_down.weight" in k for k in state_dict)
+ if is_kohya:
+ state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
+ # Kohya already takes care of scaling the LoRA parameters with alpha.
+ return (state_dict, None) if return_alphas else state_dict
+
+ is_xlabs = any("processor" in k for k in state_dict)
+ if is_xlabs:
+ state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
+ # xlabs doesn't use `alpha`.
+ return (state_dict, None) if return_alphas else state_dict
+
+ is_bfl_control = any("query_norm.scale" in k for k in state_dict)
+ if is_bfl_control:
+ state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
+ return (state_dict, None) if return_alphas else state_dict

  # For state dicts like
  # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
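The three new branches above detect foreign Flux LoRA layouts purely from key substrings before handing off to the converters added in `lora_conversion_utils.py` (+647 lines in the file list). A self-contained sketch of just that detection logic; the helper name is hypothetical, but the marker substrings are taken straight from the hunk:

```python
# Sketch only: a standalone restatement of the key-based format heuristics
# used by FluxLoraLoaderMixin.lora_state_dict above.
def detect_flux_lora_format(state_dict: dict) -> str:
    if any(".lora_down.weight" in k for k in state_dict):
        return "kohya"        # handled by _convert_kohya_flux_lora_to_diffusers
    if any("processor" in k for k in state_dict):
        return "xlabs"        # handled by _convert_xlabs_flux_lora_to_diffusers
    if any("query_norm.scale" in k for k in state_dict):
        return "bfl_control"  # handled by _convert_bfl_flux_control_lora_to_diffusers
    return "diffusers"        # already in the peft/diffusers layout

print(detect_flux_lora_format({"lora_unet_double_blocks_0.lora_down.weight": None}))
# -> "kohya"
```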
@@ -1621,10 +1811,19 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  adapter_name (`str`, *optional*):
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
  `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")

+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
  # if a dict is passed, copy it instead of modifying it inplace
  if isinstance(pretrained_model_name_or_path_or_dict, dict):
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -1634,18 +1833,57 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
  pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
  )

- is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
- if not is_correct_format:
+ has_lora_keys = any("lora" in key for key in state_dict.keys())
+
+ # Flux Control LoRAs also have norm keys
+ has_norm_keys = any(
+ norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys
+ )
+
+ if not (has_lora_keys or has_norm_keys):
  raise ValueError("Invalid LoRA checkpoint.")

- self.load_lora_into_transformer(
- state_dict,
- network_alphas=network_alphas,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
+ transformer_lora_state_dict = {
+ k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
+ }
+ transformer_norm_state_dict = {
+ k: state_dict.pop(k)
+ for k in list(state_dict.keys())
+ if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
+ }
+
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+ has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
+ transformer, transformer_lora_state_dict, transformer_norm_state_dict
+ )
+
+ if has_param_with_expanded_shape:
+ logger.info(
+ "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
+ "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
+ "To get a comprehensive list of parameter names that were modified, enable debug logging."
+ )
+ transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
+ transformer=transformer, lora_state_dict=transformer_lora_state_dict
  )

+ if len(transformer_lora_state_dict) > 0:
+ self.load_lora_into_transformer(
+ transformer_lora_state_dict,
+ network_alphas=network_alphas,
+ transformer=transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ if len(transformer_norm_state_dict) > 0:
+ transformer._transformer_norm_layers = self._load_norm_into_transformer(
+ transformer_norm_state_dict,
+ transformer=transformer,
+ discard_original_layers=False,
+ )
+
  text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
  if len(text_encoder_state_dict) > 0:
  self.load_lora_into_text_encoder(
@@ -1656,10 +1894,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
                 _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )
 
     @classmethod
-    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
+    def load_lora_into_transformer(
+        cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
 
@@ -1672,78 +1913,86 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 The value of the network alpha used for stable learning and preventing underflow. This value has the
                 same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                 link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
-            transformer (`SD3Transformer2DModel`):
+            transformer (`FluxTransformer2DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
 
+        # Load the layers corresponding to transformer.
         keys = list(state_dict.keys())
+        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
+        if transformer_present:
+            logger.info(f"Loading {cls.transformer_name}.")
+            transformer.load_lora_adapter(
+                state_dict,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
 
-        transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
-        state_dict = {
-            k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
-        }
+    @classmethod
+    def _load_norm_into_transformer(
+        cls,
+        state_dict,
+        transformer,
+        prefix=None,
+        discard_original_layers=False,
+    ) -> Dict[str, torch.Tensor]:
+        # Remove prefix if present
+        prefix = prefix or cls.transformer_name
+        for key in list(state_dict.keys()):
+            if key.split(".")[0] == prefix:
+                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+
+        # Find invalid keys
+        transformer_state_dict = transformer.state_dict()
+        transformer_keys = set(transformer_state_dict.keys())
+        state_dict_keys = set(state_dict.keys())
+        extra_keys = list(state_dict_keys - transformer_keys)
+
+        if extra_keys:
+            logger.warning(
+                f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
+            )
 
-        if len(state_dict.keys()) > 0:
-            # check with first key if is not in peft format
-            first_key = next(iter(state_dict.keys()))
-            if "lora_A" not in first_key:
-                state_dict = convert_unet_state_dict_to_peft(state_dict)
+        for key in extra_keys:
+            state_dict.pop(key)
 
-            if adapter_name in getattr(transformer, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )
+        # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected
+        overwritten_layers_state_dict = {}
+        if not discard_original_layers:
+            for key in state_dict.keys():
+                overwritten_layers_state_dict[key] = transformer_state_dict[key].clone()
 
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
+        logger.info(
+            "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
+            'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly '
+            "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. "
+            "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues."
+        )
 
-            if network_alphas is not None and len(network_alphas) >= 1:
-                prefix = cls.transformer_name
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
-                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+        # We can't load with strict=True because the current state_dict does not contain all the transformer keys
+        incompatible_keys = transformer.load_state_dict(state_dict, strict=False)
+        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
 
-            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
-                else:
-                    lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(transformer)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
-            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
-
-            if incompatible_keys is not None:
-                # check only for unexpected keys
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    logger.warning(
-                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                        f" {unexpected_keys}. "
-                    )
+        # The supported norm keys should not show up among the unexpected keys.
+        if unexpected_keys:
+            if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys):
+                raise ValueError(
+                    f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer."
+                )
 
-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
+        return overwritten_layers_state_dict
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
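
`_load_norm_into_transformer` writes the norm weights straight into the model (they are not peft adapters), backing up whatever it overwrites so `unload_lora_weights` can restore it. A condensed sketch of that save-then-overwrite pattern, with hypothetical helper names:

```py
import torch

def load_norms(model: torch.nn.Module, norm_sd: dict) -> dict:
    # Back up the parameters that are about to be overwritten, then write the norms in place.
    current = model.state_dict()
    backup = {k: current[k].clone() for k in norm_sd}
    model.load_state_dict(norm_sd, strict=False)  # strict=False: norm_sd covers only a subset of keys
    return backup

def unload_norms(model: torch.nn.Module, backup: dict) -> None:
    model.load_state_dict(backup, strict=False)  # restore the original norm parameters
```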
@@ -1756,6 +2005,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         lora_scale=1.0,
         adapter_name=None,
         _pipeline=None,
+        low_cpu_mem_usage=False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1765,7 +2015,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 A standard state dict containing the lora layer parameters. The key should be prefixed with an
                 additional `text_encoder` to distinguish between unet lora layers.
             network_alphas (`Dict[str, float]`):
-                See `LoRALinearLayer` for more details.
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
             text_encoder (`CLIPTextModel`):
                 The text encoder model to load the LoRA layers into.
             prefix (`str`):
@@ -1776,10 +2028,27 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        peft_kwargs = {}
+        if low_cpu_mem_usage:
+            if not is_peft_version(">=", "0.13.1"):
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+                )
+            if not is_transformers_version(">", "4.45.2"):
+                # Note from sayakpaul: It's not in `transformers` stable yet.
+                # https://github.com/huggingface/transformers/pull/33725/
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+                )
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
         from peft import LoraConfig
 
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1827,6 +2096,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                }
 
            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
            if "use_dora" in lora_config_kwargs:
                if lora_config_kwargs["use_dora"]:
                    if is_peft_version("<", "0.9.0"):
@@ -1836,6 +2106,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                else:
                    if is_peft_version("<", "0.9.0"):
                        lora_config_kwargs.pop("use_dora")
+
+            if "lora_bias" in lora_config_kwargs:
+                if lora_config_kwargs["lora_bias"]:
+                    if is_peft_version("<=", "0.13.2"):
+                        raise ValueError(
+                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<=", "0.13.2"):
+                        lora_config_kwargs.pop("lora_bias")
+
            lora_config = LoraConfig(**lora_config_kwargs)
 
            # adapter_name
@@ -1850,6 +2131,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                adapter_name=adapter_name,
                adapter_state_dict=text_encoder_lora_state_dict,
                peft_config=lora_config,
+                **peft_kwargs,
            )
 
            # scale LoRA layers with `lora_scale`
@@ -1919,7 +2201,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             safe_serialization=safe_serialization,
         )
 
-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
     def fuse_lora(
         self,
         components: List[str] = ["transformer", "text_encoder"],
@@ -1959,6 +2240,19 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         pipeline.fuse_lora(lora_scale=0.7)
         ```
         """
+
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        if (
+            hasattr(transformer, "_transformer_norm_layers")
+            and isinstance(transformer._transformer_norm_layers, dict)
+            and len(transformer._transformer_norm_layers.keys()) > 0
+        ):
+            logger.info(
+                "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
+                "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly "
+                "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
+            )
+
         super().fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
@@ -1977,8 +2271,168 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
         """
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
+            transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+
         super().unfuse_lora(components=components)
 
+    # We override this here to account for `_transformer_norm_layers`.
+    def unload_lora_weights(self):
+        super().unload_lora_weights()
+
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
+            transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+            transformer._transformer_norm_layers = None
+
+    @classmethod
+    def _maybe_expand_transformer_param_shape_or_error_(
+        cls,
+        transformer: torch.nn.Module,
+        lora_state_dict=None,
+        norm_state_dict=None,
+        prefix=None,
+    ) -> bool:
+        """
+        Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
+        generalizes things a bit so that any parameter that needs expansion receives appropriate treatment.
+        """
+        state_dict = {}
+        if lora_state_dict is not None:
+            state_dict.update(lora_state_dict)
+        if norm_state_dict is not None:
+            state_dict.update(norm_state_dict)
+
+        # Remove prefix if present
+        prefix = prefix or cls.transformer_name
+        for key in list(state_dict.keys()):
+            if key.split(".")[0] == prefix:
+                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+
+        # Expand transformer parameter shapes if they don't match lora
+        has_param_with_shape_update = False
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        for name, module in transformer.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                module_weight = module.weight.data
+                module_bias = module.bias.data if module.bias is not None else None
+                bias = module_bias is not None
+
+                lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
+                lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
+                lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
+                if lora_A_weight_name not in state_dict:
+                    continue
+
+                in_features = state_dict[lora_A_weight_name].shape[1]
+                out_features = state_dict[lora_B_weight_name].shape[0]
+
+                # This means there's no need for an expansion in the params, so we simply skip.
+                if tuple(module_weight.shape) == (out_features, in_features):
+                    continue
+
+                module_out_features, module_in_features = module_weight.shape
+                debug_message = ""
+                if in_features > module_in_features:
+                    debug_message += (
+                        f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
+                        f"checkpoint contains a higher number of features than expected. The number of input_features will be "
+                        f"expanded from {module_in_features} to {in_features}"
+                    )
+                if out_features > module_out_features:
+                    debug_message += (
+                        ", and the number of output features will be "
+                        f"expanded from {module_out_features} to {out_features}."
+                    )
+                else:
+                    debug_message += "."
+                if debug_message:
+                    logger.debug(debug_message)
+
+                if out_features > module_out_features or in_features > module_in_features:
+                    has_param_with_shape_update = True
+                    parent_module_name, _, current_module_name = name.rpartition(".")
+                    parent_module = transformer.get_submodule(parent_module_name)
+
+                    with torch.device("meta"):
+                        expanded_module = torch.nn.Linear(
+                            in_features, out_features, bias=bias, dtype=module_weight.dtype
+                        )
+                    # Only weights are expanded and biases are not. This is because only the input dimensions
+                    # are changed while the output dimensions remain the same. The shape of the weight tensor
+                    # is (out_features, in_features), while the shape of the bias tensor is (out_features,), which
+                    # explains the reason why only weights are expanded.
+                    new_weight = torch.zeros_like(
+                        expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+                    )
+                    slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                    new_weight[slices] = module_weight
+                    tmp_state_dict = {"weight": new_weight}
+                    if module_bias is not None:
+                        tmp_state_dict["bias"] = module_bias
+                    expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
+
+                    setattr(parent_module, current_module_name, expanded_module)
+
+                    del tmp_state_dict
+
+                    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+                        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+                        new_value = int(expanded_module.weight.data.shape[1])
+                        old_value = getattr(transformer.config, attribute_name)
+                        setattr(transformer.config, attribute_name, new_value)
+                        logger.info(
+                            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+                        )
+
+        return has_param_with_shape_update
+
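
A standalone illustration of the expansion performed above, with toy sizes rather than Flux's real dimensions: widen a `Linear` from 64 to 128 input features, keeping the original weights in the leading columns so the layer behaves identically on the old feature range:

```py
import torch

old = torch.nn.Linear(64, 32)
new = torch.nn.Linear(128, 32, bias=True)
with torch.no_grad():
    new.weight.zero_()
    new.weight[:, :64] = old.weight  # original weights fill the first 64 input columns
    new.bias.copy_(old.bias)  # bias has shape (out_features,), so it needs no expansion

x = torch.randn(1, 64)
x_padded = torch.cat([x, torch.zeros(1, 64)], dim=1)  # zeros in the new input slots
assert torch.allclose(new(x_padded), old(x), atol=1e-6)
```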
+    @classmethod
+    def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
+        expanded_module_names = set()
+        transformer_state_dict = transformer.state_dict()
+        prefix = f"{cls.transformer_name}."
+
+        lora_module_names = [
+            key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
+        ]
+        lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
+        lora_module_names = sorted(set(lora_module_names))
+        transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
+        unexpected_modules = set(lora_module_names) - set(transformer_module_names)
+        if unexpected_modules:
+            logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
+
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        for k in lora_module_names:
+            if k in unexpected_modules:
+                continue
+
+            base_param_name = (
+                f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+            )
+            base_weight_param = transformer_state_dict[base_param_name]
+            lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
+
+            if base_weight_param.shape[1] > lora_A_param.shape[1]:
+                shape = (lora_A_param.shape[0], base_weight_param.shape[1])
+                expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
+                expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
+                lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
+                expanded_module_names.add(k)
+            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+                raise NotImplementedError(
+                    f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file a feature request - https://github.com/huggingface/diffusers/issues/new."
+                )
+
+        if expanded_module_names:
+            logger.info(
+                f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
+            )
+
+        return lora_state_dict
+
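
The zero padding in `_maybe_expand_lora_state_dict` is lossless for the LoRA update itself: padding `lora_A` with zero columns only appends zero columns to the delta `B @ A`. A quick numerical check:

```py
import torch

rank, in_f, out_f, new_in_f = 4, 8, 6, 12
A = torch.randn(rank, in_f)
B = torch.randn(out_f, rank)

A_pad = torch.zeros(rank, new_in_f)
A_pad[:, :in_f] = A  # same layout as the expansion above

delta, delta_pad = B @ A, B @ A_pad
assert torch.equal(delta_pad[:, :in_f], delta)  # original columns are untouched
assert not delta_pad[:, in_f:].any()  # the added input features contribute nothing
```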
 
 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
@@ -1988,7 +2442,10 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     text_encoder_name = TEXT_ENCODER_NAME
 
     @classmethod
-    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
+    # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
+    def load_lora_into_transformer(
+        cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
 
@@ -1998,78 +2455,35 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
             network_alphas (`Dict[str, float]`):
-                See `LoRALinearLayer` for more details.
-            unet (`UNet2DConditionModel`):
-                The UNet model to load the LoRA layers into.
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+            transformer (`UVit2DModel`):
+                The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for this method.")
-
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
 
+        # Load the layers corresponding to transformer.
         keys = list(state_dict.keys())
-
-        transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
-        state_dict = {
-            k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
-        }
-
-        if network_alphas is not None:
-            alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)]
-            network_alphas = {
-                k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-            }
-
-        if len(state_dict.keys()) > 0:
-            if adapter_name in getattr(transformer, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )
-
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
-
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
-                else:
-                    lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(transformer)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
-            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
-
-            if incompatible_keys is not None:
-                # check only for unexpected keys
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    logger.warning(
-                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                        f" {unexpected_keys}. "
-                    )
-
-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
+        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
+        if transformer_present:
+            logger.info(f"Loading {cls.transformer_name}.")
+            transformer.load_lora_adapter(
+                state_dict,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2082,6 +2496,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
         lora_scale=1.0,
         adapter_name=None,
         _pipeline=None,
+        low_cpu_mem_usage=False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -2091,7 +2506,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
                 A standard state dict containing the lora layer parameters. The key should be prefixed with an
                 additional `text_encoder` to distinguish between unet lora layers.
             network_alphas (`Dict[str, float]`):
-                See `LoRALinearLayer` for more details.
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
             text_encoder (`CLIPTextModel`):
                 The text encoder model to load the LoRA layers into.
             prefix (`str`):
@@ -2102,10 +2519,27 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        peft_kwargs = {}
+        if low_cpu_mem_usage:
+            if not is_peft_version(">=", "0.13.1"):
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+                )
+            if not is_transformers_version(">", "4.45.2"):
+                # Note from sayakpaul: It's not in `transformers` stable yet.
+                # https://github.com/huggingface/transformers/pull/33725/
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+                )
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
         from peft import LoraConfig
 
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -2153,6 +2587,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
                }
 
            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
            if "use_dora" in lora_config_kwargs:
                if lora_config_kwargs["use_dora"]:
                    if is_peft_version("<", "0.9.0"):
@@ -2162,6 +2597,17 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
                else:
                    if is_peft_version("<", "0.9.0"):
                        lora_config_kwargs.pop("use_dora")
+
+            if "lora_bias" in lora_config_kwargs:
+                if lora_config_kwargs["lora_bias"]:
+                    if is_peft_version("<=", "0.13.2"):
+                        raise ValueError(
+                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<=", "0.13.2"):
+                        lora_config_kwargs.pop("lora_bias")
+
            lora_config = LoraConfig(**lora_config_kwargs)
 
            # adapter_name
@@ -2176,6 +2622,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
                adapter_name=adapter_name,
                adapter_state_dict=text_encoder_lora_state_dict,
                peft_config=lora_config,
+                **peft_kwargs,
            )
 
            # scale LoRA layers with `lora_scale`
@@ -2245,6 +2692,1545 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
         )
 
 
+class CogVideoXLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue: https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        return state_dict
+
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
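
A minimal usage sketch for the new mixin, assuming a CogVideoX pipeline and a placeholder LoRA repo id:

```py
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
# "user/cogvideox-lora" is a placeholder repo id, not a real checkpoint.
pipe.load_lora_weights("user/cogvideox-lora", adapter_name="motion")
pipe.set_adapters(["motion"], [0.9])  # standard peft-backed adapter scaling
```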
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`CogVideoXTransformer3DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
+    @classmethod
+    # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
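
The save path, sketched end to end. The directory and repo ids are illustrative, and `transformer_lora_layers` is assumed to come from a transformer that already has an adapter injected (diffusers training scripts typically obtain it via `get_peft_model_state_dict`):

```py
import torch
from peft.utils import get_peft_model_state_dict
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("user/cogvideox-lora")  # placeholder repo id
lora_layers = get_peft_model_state_dict(pipe.transformer)
CogVideoXPipeline.save_lora_weights(
    save_directory="./cogvideox-lora",  # created if it does not exist
    transformer_lora_layers=lora_layers,
    safe_serialization=True,  # writes a .safetensors file
)
```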
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer", "text_encoder"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
+    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+            unfuse_text_encoder (`bool`, defaults to `True`):
+                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+                LoRA parameters then it won't have any effect.
+        """
+        super().unfuse_lora(components=components)
+
+
+class Mochi1LoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue: https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        return state_dict
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`MochiTransformer3DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
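
Both video mixins funnel into the model-level `load_lora_adapter` entry point added in this release, so an adapter can also be loaded straight into the transformer without a pipeline. A sketch with a placeholder repo id:

```py
import torch
from diffusers import MochiTransformer3DModel

transformer = MochiTransformer3DModel.from_pretrained(
    "genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.bfloat16
)
# "user/mochi-lora" is a placeholder repo id; keys are expected under the "transformer." prefix.
transformer.load_lora_adapter("user/mochi-lora", adapter_name="style")
```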
3196
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when
+ you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+ main process to avoid race conditions.
+ weight_name (`str`, *optional*, defaults to `None`):
+ Name of the file to serialize the state dict with.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
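Editor's note: a minimal save/reload sketch for `save_lora_weights`, not part of the diff. The layer names and shapes below are illustrative stand-ins for a real LoRA state dict produced during training.

```py
import torch
from diffusers import MochiPipeline

# Illustrative stand-in for a trained transformer LoRA state dict.
lora_layers = {
    "transformer_blocks.0.attn1.to_q.lora_A.weight": torch.zeros(4, 3072),
    "transformer_blocks.0.attn1.to_q.lora_B.weight": torch.zeros(3072, 4),
}
MochiPipeline.save_lora_weights(
    "./mochi-lora",
    transformer_lora_layers=lora_layers,
    weight_name="pytorch_lora_weights.safetensors",
)
# Reload later with: pipe.load_lora_weights("./mochi-lora")
```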
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer", "text_encoder"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check the fused weights for NaN values before fusing and, if NaN values are present, to
+ skip fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ unfuse_text_encoder (`bool`, defaults to `True`):
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters then it won't have any effect.
+ """
+ super().unfuse_lora(components=components)
+
+
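Editor's note: an end-to-end usage sketch for the Mochi LoRA loader mixin above (the LoRA repo id is hypothetical):

```py
import torch
from diffusers import MochiPipeline

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("your-org/mochi-lora", adapter_name="style")  # placeholder repo id
pipe.fuse_lora(components=["transformer"], lora_scale=0.9)
frames = pipe("a corgi surfing a wave at sunset", num_inference_steps=28).frames[0]
```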
+ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ """
+ # Load the main state dict first, which has the LoRA layers for either the
+ # transformer or the text encoder, or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
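Editor's note: `lora_state_dict` is a classmethod, so a checkpoint can be inspected without instantiating the pipeline. A short sketch with a placeholder repo id:

```py
from diffusers import LTXPipeline

# DoRA keys (`dora_scale`) are filtered out, per the warning above.
state_dict = LTXPipeline.lora_state_dict("your-org/ltx-video-lora")
print(len(state_dict), "keys; first:", next(iter(state_dict)))
```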
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All
+ kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it in place
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
+ them from the text encoder lora layers.
+ transformer (`LTXVideoTransformer3DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when
+ you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+ main process to avoid race conditions.
+ weight_name (`str`, *optional*, defaults to `None`):
+ Name of the file to serialize the state dict with.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer", "text_encoder"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check the fused weights for NaN values before fusing and, if NaN values are present, to
+ skip fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ unfuse_text_encoder (`bool`, defaults to `True`):
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters then it won't have any effect.
+ """
+ super().unfuse_lora(components=components)
+
+
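Editor's note: a minimal usage sketch for `LTXVideoLoraLoaderMixin` (the LoRA repo id is hypothetical):

```py
import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("your-org/ltx-video-lora", adapter_name="motion")  # placeholder repo id
video = pipe("a timelapse of clouds rolling over mountains", num_frames=65).frames[0]
```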
+ class SanaLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ """
+ # Load the main state dict first, which has the LoRA layers for either the
+ # transformer or the text encoder, or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All
+ kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it in place
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
+ them from the text encoder lora layers.
+ transformer (`SanaTransformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when
+ you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+ main process to avoid race conditions.
+ weight_name (`str`, *optional*, defaults to `None`):
+ Name of the file to serialize the state dict with.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
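Editor's note: one plausible way to obtain the `transformer_lora_layers` dict that `save_lora_weights` expects is to collect the PEFT adapter state from the transformer, as the diffusers training scripts do. A sketch, assuming an adapter is already attached (the LoRA repo id is a placeholder):

```py
from diffusers import SanaPipeline
from peft import get_peft_model_state_dict

pipe = SanaPipeline.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers")
pipe.load_lora_weights("your-org/sana-lora")  # placeholder; attaches a PEFT adapter
lora_layers = get_peft_model_state_dict(pipe.transformer)
SanaPipeline.save_lora_weights("./sana-lora-roundtrip", transformer_lora_layers=lora_layers)
```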
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer", "text_encoder"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check the fused weights for NaN values before fusing and, if NaN values are present, to
+ skip fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ unfuse_text_encoder (`bool`, defaults to `True`):
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters then it won't have any effect.
+ """
+ super().unfuse_lora(components=components)
+
+
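Editor's note: a fuse/unfuse round trip on `SanaPipeline`, sketched with a hypothetical LoRA repo:

```py
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("your-org/sana-lora")  # placeholder repo id
pipe.fuse_lora(components=["transformer"], lora_scale=0.7)
image = pipe("a pixel-art castle on a hill").images[0]
pipe.unfuse_lora(components=["transformer"])  # restore the original weights
```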
+ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ """
+ # Load the main state dict first, which has the LoRA layers for either the
+ # transformer or the text encoder, or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All
+ kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it in place
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
+ them from the text encoder lora layers.
+ transformer (`HunyuanVideoTransformer3DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when
+ you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+ main process to avoid race conditions.
+ weight_name (`str`, *optional*, defaults to `None`):
+ Name of the file to serialize the state dict with.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer", "text_encoder"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check the fused weights for NaN values before fusing and, if NaN values are present, to
+ skip fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ unfuse_text_encoder (`bool`, defaults to `True`):
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters then it won't have any effect.
+ """
+ super().unfuse_lora(components=components)
+
+
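Editor's note: a minimal usage sketch for `HunyuanVideoLoraLoaderMixin`; the base checkpoint id is assumed to be the community diffusers-format mirror, and the LoRA repo id is a placeholder:

```py
import torch
from diffusers import HunyuanVideoPipeline

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("your-org/hunyuan-video-lora")  # placeholder repo id
frames = pipe("a cat walks on the grass, realistic style", num_frames=61).frames[0]
```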
  class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
  def __init__(self, *args, **kwargs):
  deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."