diffusers-0.27.1-py3-none-any.whl → diffusers-0.32.2-py3-none-any.whl

Files changed (445)
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +41 -40
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.1.dist-info/RECORD +0 -399
  443. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
diffusers/loaders/lora_pipeline.py (new file)
@@ -0,0 +1,3812 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ from typing import Callable, Dict, List, Optional, Union
+
+ import torch
+ from huggingface_hub.utils import validate_hf_hub_args
+
+ from ..utils import (
+     USE_PEFT_BACKEND,
+     deprecate,
+     get_submodule_by_name,
+     is_peft_available,
+     is_peft_version,
+     is_torch_version,
+     is_transformers_available,
+     is_transformers_version,
+     logging,
+ )
+ from .lora_base import (  # noqa
+     LORA_WEIGHT_NAME,
+     LORA_WEIGHT_NAME_SAFE,
+     LoraBaseMixin,
+     _fetch_state_dict,
+     _load_lora_into_text_encoder,
+ )
+ from .lora_conversion_utils import (
+     _convert_bfl_flux_control_lora_to_diffusers,
+     _convert_hunyuan_video_lora_to_diffusers,
+     _convert_kohya_flux_lora_to_diffusers,
+     _convert_non_diffusers_lora_to_diffusers,
+     _convert_xlabs_flux_lora_to_diffusers,
+     _maybe_map_sgm_blocks_to_diffusers,
+ )
+
+
+ _LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
+ if is_torch_version(">=", "1.9.0"):
+     if (
+         is_peft_available()
+         and is_peft_version(">=", "0.13.1")
+         and is_transformers_available()
+         and is_transformers_version(">", "4.45.2")
+     ):
+         _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
+
+
+ logger = logging.get_logger(__name__)
+
+ TEXT_ENCODER_NAME = "text_encoder"
+ UNET_NAME = "unet"
+ TRANSFORMER_NAME = "transformer"
+
+ _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
+
+
+ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
+     r"""
+     Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
+     [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
+     """
+
+     _lora_loadable_modules = ["unet", "text_encoder"]
+     unet_name = UNET_NAME
+     text_encoder_name = TEXT_ENCODER_NAME
+
+     def load_lora_weights(
+         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+     ):
+         """
+         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+         `self.text_encoder`.
+
+         All kwargs are forwarded to `self.lora_state_dict`.
+
+         See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
+         loaded.
+
+         See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
+         loaded into `self.unet`.
+
+         See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
+         dict is loaded into `self.text_encoder`.
+
+         Parameters:
+             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+             adapter_name (`str`, *optional*):
+                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                 `default_{i}` where i is the total number of adapters being loaded.
+             low_cpu_mem_usage (`bool`, *optional*):
+                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                 weights.
+             kwargs (`dict`, *optional*):
+                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+         """
+         if not USE_PEFT_BACKEND:
+             raise ValueError("PEFT backend is required for this method.")
+
+         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+         if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+             raise ValueError(
+                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+             )
+
+         # if a dict is passed, copy it instead of modifying it inplace
+         if isinstance(pretrained_model_name_or_path_or_dict, dict):
+             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+         is_correct_format = all("lora" in key for key in state_dict.keys())
+         if not is_correct_format:
+             raise ValueError("Invalid LoRA checkpoint.")
+
+         self.load_lora_into_unet(
+             state_dict,
+             network_alphas=network_alphas,
+             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
+             adapter_name=adapter_name,
+             _pipeline=self,
+             low_cpu_mem_usage=low_cpu_mem_usage,
+         )
+         self.load_lora_into_text_encoder(
+             state_dict,
+             network_alphas=network_alphas,
+             text_encoder=getattr(self, self.text_encoder_name)
+             if not hasattr(self, "text_encoder")
+             else self.text_encoder,
+             lora_scale=self.lora_scale,
+             adapter_name=adapter_name,
+             _pipeline=self,
+             low_cpu_mem_usage=low_cpu_mem_usage,
+         )
+
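For context (not part of the diff above): a minimal usage sketch of the `load_lora_weights` entry point just shown. The LoRA repo id and weight filename are placeholders.

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# Fetches the checkpoint via lora_state_dict(), then injects it into
# pipe.unet and pipe.text_encoder through the PEFT backend.
pipe.load_lora_weights("some-user/some-lora", weight_name="lora.safetensors", adapter_name="style")
image = pipe("a watercolor fox").images[0]
```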
+     @classmethod
+     @validate_hf_hub_args
+     def lora_state_dict(
+         cls,
+         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+         **kwargs,
+     ):
+         r"""
+         Return state dict for lora weights and the network alphas.
+
+         <Tip warning={true}>
+
+         We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+         This function is experimental and might change in the future.
+
+         </Tip>
+
+         Parameters:
+             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                 Can be either:
+
+                     - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                       the Hub.
+                     - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                       with [`ModelMixin.save_pretrained`].
+                     - A [torch state
+                       dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+             cache_dir (`Union[str, os.PathLike]`, *optional*):
+                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                 is not used.
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                 cached versions if they exist.
+
+             proxies (`Dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+             local_files_only (`bool`, *optional*, defaults to `False`):
+                 Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                 won't be downloaded from the Hub.
+             token (`str` or *bool*, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
+             revision (`str`, *optional*, defaults to `"main"`):
+                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                 allowed by Git.
+             subfolder (`str`, *optional*, defaults to `""`):
+                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+             weight_name (`str`, *optional*, defaults to None):
+                 Name of the serialized state dict file.
+         """
+         # Load the main state dict first which has the LoRA layers for either of
+         # UNet and text encoder or both.
+         cache_dir = kwargs.pop("cache_dir", None)
+         force_download = kwargs.pop("force_download", False)
+         proxies = kwargs.pop("proxies", None)
+         local_files_only = kwargs.pop("local_files_only", None)
+         token = kwargs.pop("token", None)
+         revision = kwargs.pop("revision", None)
+         subfolder = kwargs.pop("subfolder", None)
+         weight_name = kwargs.pop("weight_name", None)
+         unet_config = kwargs.pop("unet_config", None)
+         use_safetensors = kwargs.pop("use_safetensors", None)
+
+         allow_pickle = False
+         if use_safetensors is None:
+             use_safetensors = True
+             allow_pickle = True
+
+         user_agent = {
+             "file_type": "attn_procs_weights",
+             "framework": "pytorch",
+         }
+
+         state_dict = _fetch_state_dict(
+             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+             weight_name=weight_name,
+             use_safetensors=use_safetensors,
+             local_files_only=local_files_only,
+             cache_dir=cache_dir,
+             force_download=force_download,
+             proxies=proxies,
+             token=token,
+             revision=revision,
+             subfolder=subfolder,
+             user_agent=user_agent,
+             allow_pickle=allow_pickle,
+         )
+         is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+         if is_dora_scale_present:
+             warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+             logger.warning(warn_msg)
+             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+         network_alphas = None
+         # TODO: replace it with a method from `state_dict_utils`
+         if all(
+             (
+                 k.startswith("lora_te_")
+                 or k.startswith("lora_unet_")
+                 or k.startswith("lora_te1_")
+                 or k.startswith("lora_te2_")
+             )
+             for k in state_dict.keys()
+         ):
+             # Map SDXL blocks correctly.
+             if unet_config is not None:
+                 # use unet config to remap block numbers
+                 state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
+             state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
+
+         return state_dict, network_alphas
+
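For context (not part of the diff above): `lora_state_dict` is a classmethod, so a checkpoint can be inspected without building a full pipeline. The repo id below is a placeholder.

```py
from diffusers import StableDiffusionPipeline

# Downloads (or reads) the LoRA file and returns the raw tensors plus any kohya-style alphas.
state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict("some-user/some-lora")
print(len(state_dict), "tensors")
print(sorted({k.split(".")[0] for k in state_dict}))  # typically ["text_encoder", "unet"]
print(network_alphas)  # None for diffusers-format files, a dict for kohya-style checkpoints
```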
+     @classmethod
+     def load_lora_into_unet(
+         cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+     ):
+         """
+         This will load the LoRA layers specified in `state_dict` into `unet`.
+
+         Parameters:
+             state_dict (`dict`):
+                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                 encoder lora layers.
+             network_alphas (`Dict[str, float]`):
+                 The value of the network alpha used for stable learning and preventing underflow. This value has the
+                 same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                 link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+             unet (`UNet2DConditionModel`):
+                 The UNet model to load the LoRA layers into.
+             adapter_name (`str`, *optional*):
+                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                 `default_{i}` where i is the total number of adapters being loaded.
+             low_cpu_mem_usage (`bool`, *optional*):
+                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                 weights.
+         """
+         if not USE_PEFT_BACKEND:
+             raise ValueError("PEFT backend is required for this method.")
+
+         if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+             raise ValueError(
+                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+             )
+
+         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
+         # their prefixes.
+         keys = list(state_dict.keys())
+         only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
+         if not only_text_encoder:
+             # Load the layers corresponding to UNet.
+             logger.info(f"Loading {cls.unet_name}.")
+             unet.load_lora_adapter(
+                 state_dict,
+                 prefix=cls.unet_name,
+                 network_alphas=network_alphas,
+                 adapter_name=adapter_name,
+                 _pipeline=_pipeline,
+                 low_cpu_mem_usage=low_cpu_mem_usage,
+             )
+
+     @classmethod
+     def load_lora_into_text_encoder(
+         cls,
+         state_dict,
+         network_alphas,
+         text_encoder,
+         prefix=None,
+         lora_scale=1.0,
+         adapter_name=None,
+         _pipeline=None,
+         low_cpu_mem_usage=False,
+     ):
+         """
+         This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+         Parameters:
+             state_dict (`dict`):
+                 A standard state dict containing the lora layer parameters. The key should be prefixed with an
+                 additional `text_encoder` to distinguish between unet lora layers.
+             network_alphas (`Dict[str, float]`):
+                 The value of the network alpha used for stable learning and preventing underflow. This value has the
+                 same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                 link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+             text_encoder (`CLIPTextModel`):
+                 The text encoder model to load the LoRA layers into.
+             prefix (`str`):
+                 Expected prefix of the `text_encoder` in the `state_dict`.
+             lora_scale (`float`):
+                 How much to scale the output of the lora linear layer before it is added with the output of the regular
+                 lora layer.
+             adapter_name (`str`, *optional*):
+                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                 `default_{i}` where i is the total number of adapters being loaded.
+             low_cpu_mem_usage (`bool`, *optional*):
+                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                 weights.
+         """
+         _load_lora_into_text_encoder(
+             state_dict=state_dict,
+             network_alphas=network_alphas,
+             lora_scale=lora_scale,
+             text_encoder=text_encoder,
+             prefix=prefix,
+             text_encoder_name=cls.text_encoder_name,
+             adapter_name=adapter_name,
+             _pipeline=_pipeline,
+             low_cpu_mem_usage=low_cpu_mem_usage,
+         )
+
+     @classmethod
+     def save_lora_weights(
+         cls,
+         save_directory: Union[str, os.PathLike],
+         unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+         text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+         is_main_process: bool = True,
+         weight_name: str = None,
+         save_function: Callable = None,
+         safe_serialization: bool = True,
+     ):
+         r"""
+         Save the LoRA parameters corresponding to the UNet and text encoder.
+
+         Arguments:
+             save_directory (`str` or `os.PathLike`):
+                 Directory to save LoRA parameters to. Will be created if it doesn't exist.
+             unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                 State dict of the LoRA layers corresponding to the `unet`.
+             text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                 State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+                 encoder LoRA state dict because it comes from 🤗 Transformers.
+             is_main_process (`bool`, *optional*, defaults to `True`):
+                 Whether the process calling this is the main process or not. Useful during distributed training and you
+                 need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                 process to avoid race conditions.
+             save_function (`Callable`):
+                 The function to use to save the state dictionary. Useful during distributed training when you need to
+                 replace `torch.save` with another method. Can be configured with the environment variable
+                 `DIFFUSERS_SAVE_MODE`.
+             safe_serialization (`bool`, *optional*, defaults to `True`):
+                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+         """
+         state_dict = {}
+
+         if not (unet_lora_layers or text_encoder_lora_layers):
+             raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
+
+         if unet_lora_layers:
+             state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+
+         if text_encoder_lora_layers:
+             state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+
+         # Save the model
+         cls.write_lora_layers(
+             state_dict=state_dict,
+             save_directory=save_directory,
+             is_main_process=is_main_process,
+             weight_name=weight_name,
+             save_function=save_function,
+             safe_serialization=safe_serialization,
+         )
+
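For context (not part of the diff above): a sketch of the save side. In real training the state dict comes from the LoRA-wrapped UNet, e.g. via peft's `get_peft_model_state_dict`; the key and shape below are illustrative stand-ins only.

```py
import torch
from diffusers import StableDiffusionPipeline

# Stand-in for a trained UNet LoRA state dict (hypothetical key/shape).
unet_lora_layers = {"down_blocks.0.attentions.0.proj_in.lora.down.weight": torch.zeros(4, 320)}

StableDiffusionPipeline.save_lora_weights(
    save_directory="./my-lora",     # created if it doesn't exist
    unet_lora_layers=unet_lora_layers,
    safe_serialization=True,        # writes pytorch_lora_weights.safetensors
)
```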
+     def fuse_lora(
+         self,
+         components: List[str] = ["unet", "text_encoder"],
+         lora_scale: float = 1.0,
+         safe_fusing: bool = False,
+         adapter_names: Optional[List[str]] = None,
+         **kwargs,
+     ):
+         r"""
+         Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+         <Tip warning={true}>
+
+         This is an experimental API.
+
+         </Tip>
+
+         Args:
+             components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+             lora_scale (`float`, defaults to 1.0):
+                 Controls how much to influence the outputs with the LoRA parameters.
+             safe_fusing (`bool`, defaults to `False`):
+                 Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+             adapter_names (`List[str]`, *optional*):
+                 Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+         Example:
+
+         ```py
+         from diffusers import DiffusionPipeline
+         import torch
+
+         pipeline = DiffusionPipeline.from_pretrained(
+             "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+         ).to("cuda")
+         pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+         pipeline.fuse_lora(lora_scale=0.7)
+         ```
+         """
+         super().fuse_lora(
+             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+         )
+
+     def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
+         r"""
+         Reverses the effect of
+         [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+         <Tip warning={true}>
+
+         This is an experimental API.
+
+         </Tip>
+
+         Args:
+             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+             unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+             unfuse_text_encoder (`bool`, defaults to `True`):
+                 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+                 LoRA parameters then it won't have any effect.
+         """
+         super().unfuse_lora(components=components)
+
+
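For context (not part of the diff above): a fuse/unfuse round trip, extending the docstring example just shown. Fusing merges the scaled LoRA deltas into the base weights; unfusing restores them.

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

pipeline.fuse_lora(lora_scale=0.7)  # merge LoRA into the base weights (no per-step LoRA overhead)
image = pipeline("pixel art, a cozy cabin").images[0]
pipeline.unfuse_lora()              # undo the merge, recovering the original weights
```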
+ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
+     r"""
+     Load LoRA layers into Stable Diffusion XL [`UNet2DConditionModel`],
+     [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
+     [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
+     """
+
+     _lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"]
+     unet_name = UNET_NAME
+     text_encoder_name = TEXT_ENCODER_NAME
+
+     def load_lora_weights(
+         self,
+         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+         adapter_name: Optional[str] = None,
+         **kwargs,
+     ):
+         """
+         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+         `self.text_encoder`.
+
+         All kwargs are forwarded to `self.lora_state_dict`.
+
+         See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
+         loaded.
+
+         See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
+         loaded into `self.unet`.
+
+         See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
+         dict is loaded into `self.text_encoder`.
+
+         Parameters:
+             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+             adapter_name (`str`, *optional*):
+                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                 `default_{i}` where i is the total number of adapters being loaded.
+             low_cpu_mem_usage (`bool`, *optional*):
+                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                 weights.
+             kwargs (`dict`, *optional*):
+                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+         """
+         if not USE_PEFT_BACKEND:
+             raise ValueError("PEFT backend is required for this method.")
+
+         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+         if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+             raise ValueError(
+                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+             )
+
+         # We could have accessed the unet config from `lora_state_dict()` too. We pass
+         # it here explicitly to be able to tell that it's coming from an SDXL
+         # pipeline.
+
+         # if a dict is passed, copy it instead of modifying it inplace
+         if isinstance(pretrained_model_name_or_path_or_dict, dict):
+             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+         state_dict, network_alphas = self.lora_state_dict(
+             pretrained_model_name_or_path_or_dict,
+             unet_config=self.unet.config,
+             **kwargs,
+         )
+
+         is_correct_format = all("lora" in key for key in state_dict.keys())
+         if not is_correct_format:
+             raise ValueError("Invalid LoRA checkpoint.")
+
+         self.load_lora_into_unet(
+             state_dict,
+             network_alphas=network_alphas,
+             unet=self.unet,
+             adapter_name=adapter_name,
+             _pipeline=self,
+             low_cpu_mem_usage=low_cpu_mem_usage,
+         )
+         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+         if len(text_encoder_state_dict) > 0:
+             self.load_lora_into_text_encoder(
+                 text_encoder_state_dict,
+                 network_alphas=network_alphas,
+                 text_encoder=self.text_encoder,
+                 prefix="text_encoder",
+                 lora_scale=self.lora_scale,
+                 adapter_name=adapter_name,
+                 _pipeline=self,
+                 low_cpu_mem_usage=low_cpu_mem_usage,
+             )
+
+         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+         if len(text_encoder_2_state_dict) > 0:
+             self.load_lora_into_text_encoder(
+                 text_encoder_2_state_dict,
+                 network_alphas=network_alphas,
+                 text_encoder=self.text_encoder_2,
+                 prefix="text_encoder_2",
+                 lora_scale=self.lora_scale,
+                 adapter_name=adapter_name,
+                 _pipeline=self,
+                 low_cpu_mem_usage=low_cpu_mem_usage,
+             )
+
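For context (not part of the diff above): in the SDXL mixin, keys prefixed `text_encoder.` and `text_encoder_2.` are routed to the two CLIP encoders automatically, so usage is the same as for SD; repo ids below match the fuse_lora docstring example.

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# Loads the UNet LoRA plus, if present, LoRA layers for both text encoders.
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
image = pipe("pixel art, a cozy cabin in the woods").images[0]
```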
587
+ @classmethod
588
+ @validate_hf_hub_args
589
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
590
+ def lora_state_dict(
591
+ cls,
592
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
593
+ **kwargs,
594
+ ):
595
+ r"""
596
+ Return state dict for lora weights and the network alphas.
597
+
598
+ <Tip warning={true}>
599
+
600
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
601
+
602
+ This function is experimental and might change in the future.
603
+
604
+ </Tip>
605
+
606
+ Parameters:
607
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
608
+ Can be either:
609
+
610
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
611
+ the Hub.
612
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
613
+ with [`ModelMixin.save_pretrained`].
614
+ - A [torch state
615
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
616
+
617
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
618
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
619
+ is not used.
620
+ force_download (`bool`, *optional*, defaults to `False`):
621
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
622
+ cached versions if they exist.
623
+
624
+ proxies (`Dict[str, str]`, *optional*):
625
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
626
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
627
+ local_files_only (`bool`, *optional*, defaults to `False`):
628
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
629
+ won't be downloaded from the Hub.
630
+ token (`str` or *bool*, *optional*):
631
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
632
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
633
+ revision (`str`, *optional*, defaults to `"main"`):
634
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
635
+ allowed by Git.
636
+ subfolder (`str`, *optional*, defaults to `""`):
637
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
638
+ weight_name (`str`, *optional*, defaults to None):
639
+ Name of the serialized state dict file.
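+
+ Example (a minimal sketch, reusing the checkpoint from the `fuse_lora` example below; the repository and file name are illustrative):
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ )
+ # Fetch the LoRA state dict and network alphas without loading them into the pipeline.
+ state_dict, network_alphas = pipeline.lora_state_dict(
+ "nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors"
+ )
+ print(len(state_dict), network_alphas)
+ ```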
640
+ """
641
+ # Load the main state dict first which has the LoRA layers for either of
642
+ # UNet and text encoder or both.
643
+ cache_dir = kwargs.pop("cache_dir", None)
644
+ force_download = kwargs.pop("force_download", False)
645
+ proxies = kwargs.pop("proxies", None)
646
+ local_files_only = kwargs.pop("local_files_only", None)
647
+ token = kwargs.pop("token", None)
648
+ revision = kwargs.pop("revision", None)
649
+ subfolder = kwargs.pop("subfolder", None)
650
+ weight_name = kwargs.pop("weight_name", None)
651
+ unet_config = kwargs.pop("unet_config", None)
652
+ use_safetensors = kwargs.pop("use_safetensors", None)
653
+
654
+ allow_pickle = False
655
+ if use_safetensors is None:
656
+ use_safetensors = True
657
+ allow_pickle = True
658
+
659
+ user_agent = {
660
+ "file_type": "attn_procs_weights",
661
+ "framework": "pytorch",
662
+ }
663
+
664
+ state_dict = _fetch_state_dict(
665
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
666
+ weight_name=weight_name,
667
+ use_safetensors=use_safetensors,
668
+ local_files_only=local_files_only,
669
+ cache_dir=cache_dir,
670
+ force_download=force_download,
671
+ proxies=proxies,
672
+ token=token,
673
+ revision=revision,
674
+ subfolder=subfolder,
675
+ user_agent=user_agent,
676
+ allow_pickle=allow_pickle,
677
+ )
678
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
679
+ if is_dora_scale_present:
680
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
681
+ logger.warning(warn_msg)
682
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
683
+
684
+ network_alphas = None
685
+ # TODO: replace it with a method from `state_dict_utils`
686
+ if all(
687
+ (
688
+ k.startswith("lora_te_")
689
+ or k.startswith("lora_unet_")
690
+ or k.startswith("lora_te1_")
691
+ or k.startswith("lora_te2_")
692
+ )
693
+ for k in state_dict.keys()
694
+ ):
695
+ # Map SDXL blocks correctly.
696
+ if unet_config is not None:
697
+ # use unet config to remap block numbers
698
+ state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
699
+ state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
700
+
701
+ return state_dict, network_alphas
702
+
703
+ @classmethod
704
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
705
+ def load_lora_into_unet(
706
+ cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
707
+ ):
708
+ """
709
+ This will load the LoRA layers specified in `state_dict` into `unet`.
710
+
711
+ Parameters:
712
+ state_dict (`dict`):
713
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
714
+ into the unet or prefixed with an additional `unet`, which can be used to distinguish them from the text
715
+ encoder LoRA layers.
716
+ network_alphas (`Dict[str, float]`):
717
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
718
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
719
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
720
+ unet (`UNet2DConditionModel`):
721
+ The UNet model to load the LoRA layers into.
722
+ adapter_name (`str`, *optional*):
723
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
724
+ `default_{i}` where i is the total number of adapters being loaded.
725
+ low_cpu_mem_usage (`bool`, *optional*):
726
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
727
+ weights.
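+
+ Example (a hedged sketch; `pipeline` is an SDXL-style pipeline and the checkpoint follows the examples in this class):
+
+ ```py
+ state_dict, network_alphas = pipeline.lora_state_dict(
+ "nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors"
+ )
+ # Load only the UNet part of the checkpoint; "pixel" is an arbitrary adapter name.
+ pipeline.load_lora_into_unet(
+ state_dict, network_alphas=network_alphas, unet=pipeline.unet, adapter_name="pixel"
+ )
+ ```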
728
+ """
729
+ if not USE_PEFT_BACKEND:
730
+ raise ValueError("PEFT backend is required for this method.")
731
+
732
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
733
+ raise ValueError(
734
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
735
+ )
736
+
737
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
738
+ # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
739
+ # their prefixes.
740
+ keys = list(state_dict.keys())
741
+ only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
742
+ if not only_text_encoder:
743
+ # Load the layers corresponding to UNet.
744
+ logger.info(f"Loading {cls.unet_name}.")
745
+ unet.load_lora_adapter(
746
+ state_dict,
747
+ prefix=cls.unet_name,
748
+ network_alphas=network_alphas,
749
+ adapter_name=adapter_name,
750
+ _pipeline=_pipeline,
751
+ low_cpu_mem_usage=low_cpu_mem_usage,
752
+ )
753
+
754
+ @classmethod
755
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
756
+ def load_lora_into_text_encoder(
757
+ cls,
758
+ state_dict,
759
+ network_alphas,
760
+ text_encoder,
761
+ prefix=None,
762
+ lora_scale=1.0,
763
+ adapter_name=None,
764
+ _pipeline=None,
765
+ low_cpu_mem_usage=False,
766
+ ):
767
+ """
768
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
769
+
770
+ Parameters:
771
+ state_dict (`dict`):
772
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
773
+ additional `text_encoder` to distinguish them from the UNet LoRA layers.
774
+ network_alphas (`Dict[str, float]`):
775
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
776
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
777
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
778
+ text_encoder (`CLIPTextModel`):
779
+ The text encoder model to load the LoRA layers into.
780
+ prefix (`str`):
781
+ Expected prefix of the `text_encoder` in the `state_dict`.
782
+ lora_scale (`float`):
783
+ How much to scale the output of the LoRA linear layer before it is added to the output of the regular
784
+ layer.
785
+ adapter_name (`str`, *optional*):
786
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
787
+ `default_{i}` where i is the total number of adapters being loaded.
788
+ low_cpu_mem_usage (`bool`, *optional*):
789
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
790
+ weights.
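+
+ Example (a hedged sketch; `pipeline`, `state_dict`, and `network_alphas` come from the `load_lora_into_unet` example above):
+
+ ```py
+ # Keep only the text encoder entries before handing them over.
+ te_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+ pipeline.load_lora_into_text_encoder(
+ te_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=pipeline.text_encoder,
+ prefix="text_encoder",
+ )
+ ```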
791
+ """
792
+ _load_lora_into_text_encoder(
793
+ state_dict=state_dict,
794
+ network_alphas=network_alphas,
795
+ lora_scale=lora_scale,
796
+ text_encoder=text_encoder,
797
+ prefix=prefix,
798
+ text_encoder_name=cls.text_encoder_name,
799
+ adapter_name=adapter_name,
800
+ _pipeline=_pipeline,
801
+ low_cpu_mem_usage=low_cpu_mem_usage,
802
+ )
803
+
804
+ @classmethod
805
+ def save_lora_weights(
806
+ cls,
807
+ save_directory: Union[str, os.PathLike],
808
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
809
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
810
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
811
+ is_main_process: bool = True,
812
+ weight_name: str = None,
813
+ save_function: Callable = None,
814
+ safe_serialization: bool = True,
815
+ ):
816
+ r"""
817
+ Save the LoRA parameters corresponding to the UNet and the text encoders.
818
+
819
+ Arguments:
820
+ save_directory (`str` or `os.PathLike`):
821
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
822
+ unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
823
+ State dict of the LoRA layers corresponding to the `unet`.
824
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
825
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
826
+ encoder LoRA state dict because it comes from 🤗 Transformers.
827
+ text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
828
+ State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
829
+ encoder LoRA state dict because it comes from 🤗 Transformers.
830
+ is_main_process (`bool`, *optional*, defaults to `True`):
831
+ Whether the process calling this is the main process or not. Useful during distributed training when you
832
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
833
+ process to avoid race conditions.
834
+ save_function (`Callable`):
835
+ The function to use to save the state dictionary. Useful during distributed training when you need to
836
+ replace `torch.save` with another method. Can be configured with the environment variable
837
+ `DIFFUSERS_SAVE_MODE`.
838
+ safe_serialization (`bool`, *optional*, defaults to `True`):
839
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
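+
+ Example (a sketch only; `unet_lora_layers` and `text_encoder_lora_layers` are assumed to come from your training setup):
+
+ ```py
+ pipeline.save_lora_weights(
+ save_directory="./sdxl-lora",
+ unet_lora_layers=unet_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ weight_name="pytorch_lora_weights.safetensors",
+ )
+ ```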
840
+ """
841
+ state_dict = {}
842
+
843
+ if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
844
+ raise ValueError(
845
+ "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
846
+ )
847
+
848
+ if unet_lora_layers:
849
+ state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
850
+
851
+ if text_encoder_lora_layers:
852
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
853
+
854
+ if text_encoder_2_lora_layers:
855
+ state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
856
+
857
+ cls.write_lora_layers(
858
+ state_dict=state_dict,
859
+ save_directory=save_directory,
860
+ is_main_process=is_main_process,
861
+ weight_name=weight_name,
862
+ save_function=save_function,
863
+ safe_serialization=safe_serialization,
864
+ )
865
+
866
+ def fuse_lora(
867
+ self,
868
+ components: List[str] = ["unet", "text_encoder", "text_encoder_2"],
869
+ lora_scale: float = 1.0,
870
+ safe_fusing: bool = False,
871
+ adapter_names: Optional[List[str]] = None,
872
+ **kwargs,
873
+ ):
874
+ r"""
875
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
876
+
877
+ <Tip warning={true}>
878
+
879
+ This is an experimental API.
880
+
881
+ </Tip>
882
+
883
+ Args:
884
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
885
+ lora_scale (`float`, defaults to 1.0):
886
+ Controls how much to influence the outputs with the LoRA parameters.
887
+ safe_fusing (`bool`, defaults to `False`):
888
+ Whether to check fused weights for NaN values before fusing and, if NaN values are present, to skip fusing them.
889
+ adapter_names (`List[str]`, *optional*):
890
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
891
+
892
+ Example:
893
+
894
+ ```py
895
+ from diffusers import DiffusionPipeline
896
+ import torch
897
+
898
+ pipeline = DiffusionPipeline.from_pretrained(
899
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
900
+ ).to("cuda")
901
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
902
+ pipeline.fuse_lora(lora_scale=0.7)
903
+ ```
904
+ """
905
+ super().fuse_lora(
906
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
907
+ )
908
+
909
+ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
910
+ r"""
911
+ Reverses the effect of
912
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
913
+
914
+ <Tip warning={true}>
915
+
916
+ This is an experimental API.
917
+
918
+ </Tip>
919
+
920
+ Args:
921
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
922
+ unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
923
+ unfuse_text_encoder (`bool`, defaults to `True`):
924
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
925
+ LoRA parameters then it won't have any effect.
926
+ """
927
+ super().unfuse_lora(components=components)
928
+
929
+
930
+ class SD3LoraLoaderMixin(LoraBaseMixin):
931
+ r"""
932
+ Load LoRA layers into [`SD3Transformer2DModel`],
933
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
934
+ [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
935
+
936
+ Specific to [`StableDiffusion3Pipeline`].
937
+ """
938
+
939
+ _lora_loadable_modules = ["transformer", "text_encoder", "text_encoder_2"]
940
+ transformer_name = TRANSFORMER_NAME
941
+ text_encoder_name = TEXT_ENCODER_NAME
942
+
943
+ @classmethod
944
+ @validate_hf_hub_args
945
+ def lora_state_dict(
946
+ cls,
947
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
948
+ **kwargs,
949
+ ):
950
+ r"""
951
+ Return state dict for lora weights and the network alphas.
952
+
953
+ <Tip warning={true}>
954
+
955
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
956
+
957
+ This function is experimental and might change in the future.
958
+
959
+ </Tip>
960
+
961
+ Parameters:
962
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
963
+ Can be either:
964
+
965
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
966
+ the Hub.
967
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
968
+ with [`ModelMixin.save_pretrained`].
969
+ - A [torch state
970
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
971
+
972
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
973
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
974
+ is not used.
975
+ force_download (`bool`, *optional*, defaults to `False`):
976
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
977
+ cached versions if they exist.
978
+
979
+ proxies (`Dict[str, str]`, *optional*):
980
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
981
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
982
+ local_files_only (`bool`, *optional*, defaults to `False`):
983
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
984
+ won't be downloaded from the Hub.
985
+ token (`str` or *bool*, *optional*):
986
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
987
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
988
+ revision (`str`, *optional*, defaults to `"main"`):
989
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
990
+ allowed by Git.
991
+ subfolder (`str`, *optional*, defaults to `""`):
992
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
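+
+ Example (illustrative; the LoRA repository name is a placeholder):
+
+ ```py
+ from diffusers import StableDiffusion3Pipeline
+
+ # Classmethod usage: fetch the state dict without instantiating a pipeline.
+ state_dict = StableDiffusion3Pipeline.lora_state_dict("your-username/sd3-lora")
+ print(sorted(state_dict)[:5])
+ ```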
993
+
994
+ """
995
+ # Load the main state dict first which has the LoRA layers for either of
996
+ # transformer and text encoder or both.
997
+ cache_dir = kwargs.pop("cache_dir", None)
998
+ force_download = kwargs.pop("force_download", False)
999
+ proxies = kwargs.pop("proxies", None)
1000
+ local_files_only = kwargs.pop("local_files_only", None)
1001
+ token = kwargs.pop("token", None)
1002
+ revision = kwargs.pop("revision", None)
1003
+ subfolder = kwargs.pop("subfolder", None)
1004
+ weight_name = kwargs.pop("weight_name", None)
1005
+ use_safetensors = kwargs.pop("use_safetensors", None)
1006
+
1007
+ allow_pickle = False
1008
+ if use_safetensors is None:
1009
+ use_safetensors = True
1010
+ allow_pickle = True
1011
+
1012
+ user_agent = {
1013
+ "file_type": "attn_procs_weights",
1014
+ "framework": "pytorch",
1015
+ }
1016
+
1017
+ state_dict = _fetch_state_dict(
1018
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
1019
+ weight_name=weight_name,
1020
+ use_safetensors=use_safetensors,
1021
+ local_files_only=local_files_only,
1022
+ cache_dir=cache_dir,
1023
+ force_download=force_download,
1024
+ proxies=proxies,
1025
+ token=token,
1026
+ revision=revision,
1027
+ subfolder=subfolder,
1028
+ user_agent=user_agent,
1029
+ allow_pickle=allow_pickle,
1030
+ )
1031
+
1032
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
1033
+ if is_dora_scale_present:
1034
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
1035
+ logger.warning(warn_msg)
1036
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
1037
+
1038
+ return state_dict
1039
+
1040
+ def load_lora_weights(
1041
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
1042
+ ):
1043
+ """
1044
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
1045
+ `self.text_encoder`.
1046
+
1047
+ All kwargs are forwarded to `self.lora_state_dict`.
1048
+
1049
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
1050
+ loaded.
1051
+
1052
+ See [`~loaders.SD3LoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
1053
+ dict is loaded into `self.transformer`.
1054
+
1055
+ Parameters:
1056
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
1057
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
1058
+ adapter_name (`str`, *optional*):
1059
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1060
+ `default_{i}` where i is the total number of adapters being loaded.
1061
+ low_cpu_mem_usage (`bool`, *optional*):
1062
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1063
+ weights.
1064
+ kwargs (`dict`, *optional*):
1065
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
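+
+ Example (a minimal sketch; the base model id is real, the LoRA repository is a placeholder):
+
+ ```py
+ from diffusers import StableDiffusion3Pipeline
+ import torch
+
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("your-username/sd3-lora", adapter_name="example")
+ ```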
1066
+ """
1067
+ if not USE_PEFT_BACKEND:
1068
+ raise ValueError("PEFT backend is required for this method.")
1069
+
1070
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
1071
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
1072
+ raise ValueError(
1073
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
1074
+ )
1075
+
1076
+ # if a dict is passed, copy it instead of modifying it inplace
1077
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
1078
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
1079
+
1080
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
1081
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
1082
+
1083
+ is_correct_format = all("lora" in key for key in state_dict.keys())
1084
+ if not is_correct_format:
1085
+ raise ValueError("Invalid LoRA checkpoint.")
1086
+
1087
+ transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
1088
+ if len(transformer_state_dict) > 0:
1089
+ self.load_lora_into_transformer(
1090
+ state_dict,
1091
+ transformer=getattr(self, self.transformer_name)
1092
+ if not hasattr(self, "transformer")
1093
+ else self.transformer,
1094
+ adapter_name=adapter_name,
1095
+ _pipeline=self,
1096
+ low_cpu_mem_usage=low_cpu_mem_usage,
1097
+ )
1098
+
1099
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
1100
+ if len(text_encoder_state_dict) > 0:
1101
+ self.load_lora_into_text_encoder(
1102
+ text_encoder_state_dict,
1103
+ network_alphas=None,
1104
+ text_encoder=self.text_encoder,
1105
+ prefix="text_encoder",
1106
+ lora_scale=self.lora_scale,
1107
+ adapter_name=adapter_name,
1108
+ _pipeline=self,
1109
+ low_cpu_mem_usage=low_cpu_mem_usage,
1110
+ )
1111
+
1112
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
1113
+ if len(text_encoder_2_state_dict) > 0:
1114
+ self.load_lora_into_text_encoder(
1115
+ text_encoder_2_state_dict,
1116
+ network_alphas=None,
1117
+ text_encoder=self.text_encoder_2,
1118
+ prefix="text_encoder_2",
1119
+ lora_scale=self.lora_scale,
1120
+ adapter_name=adapter_name,
1121
+ _pipeline=self,
1122
+ low_cpu_mem_usage=low_cpu_mem_usage,
1123
+ )
1124
+
1125
+ @classmethod
1126
+ def load_lora_into_transformer(
1127
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
1128
+ ):
1129
+ """
1130
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
1131
+
1132
+ Parameters:
1133
+ state_dict (`dict`):
1134
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
1135
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish them from the text
1136
+ encoder LoRA layers.
1137
+ transformer (`SD3Transformer2DModel`):
1138
+ The Transformer model to load the LoRA layers into.
1139
+ adapter_name (`str`, *optional*):
1140
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1141
+ `default_{i}` where i is the total number of adapters being loaded.
1142
+ low_cpu_mem_usage (`bool`, *optional*):
1143
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1144
+ weights.
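+
+ Example (a hedged sketch continuing from a loaded SD3 `pipeline`; the repository name is a placeholder):
+
+ ```py
+ state_dict = pipeline.lora_state_dict("your-username/sd3-lora")
+ pipeline.load_lora_into_transformer(
+ state_dict, transformer=pipeline.transformer, adapter_name="example"
+ )
+ ```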
1145
+ """
1146
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
1147
+ raise ValueError(
1148
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
1149
+ )
1150
+
1151
+ # Load the layers corresponding to transformer.
1152
+ logger.info(f"Loading {cls.transformer_name}.")
1153
+ transformer.load_lora_adapter(
1154
+ state_dict,
1155
+ network_alphas=None,
1156
+ adapter_name=adapter_name,
1157
+ _pipeline=_pipeline,
1158
+ low_cpu_mem_usage=low_cpu_mem_usage,
1159
+ )
1160
+
1161
+ @classmethod
1162
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
1163
+ def load_lora_into_text_encoder(
1164
+ cls,
1165
+ state_dict,
1166
+ network_alphas,
1167
+ text_encoder,
1168
+ prefix=None,
1169
+ lora_scale=1.0,
1170
+ adapter_name=None,
1171
+ _pipeline=None,
1172
+ low_cpu_mem_usage=False,
1173
+ ):
1174
+ """
1175
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
1176
+
1177
+ Parameters:
1178
+ state_dict (`dict`):
1179
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
1180
+ additional `text_encoder` to distinguish them from the UNet LoRA layers.
1181
+ network_alphas (`Dict[str, float]`):
1182
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
1183
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
1184
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
1185
+ text_encoder (`CLIPTextModel`):
1186
+ The text encoder model to load the LoRA layers into.
1187
+ prefix (`str`):
1188
+ Expected prefix of the `text_encoder` in the `state_dict`.
1189
+ lora_scale (`float`):
1190
+ How much to scale the output of the LoRA linear layer before it is added to the output of the regular
1191
+ layer.
1192
+ adapter_name (`str`, *optional*):
1193
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1194
+ `default_{i}` where i is the total number of adapters being loaded.
1195
+ low_cpu_mem_usage (`bool`, *optional*):
1196
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1197
+ weights.
1198
+ """
1199
+ _load_lora_into_text_encoder(
1200
+ state_dict=state_dict,
1201
+ network_alphas=network_alphas,
1202
+ lora_scale=lora_scale,
1203
+ text_encoder=text_encoder,
1204
+ prefix=prefix,
1205
+ text_encoder_name=cls.text_encoder_name,
1206
+ adapter_name=adapter_name,
1207
+ _pipeline=_pipeline,
1208
+ low_cpu_mem_usage=low_cpu_mem_usage,
1209
+ )
1210
+
1211
+ @classmethod
1212
+ def save_lora_weights(
1213
+ cls,
1214
+ save_directory: Union[str, os.PathLike],
1215
+ transformer_lora_layers: Dict[str, torch.nn.Module] = None,
1216
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1217
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1218
+ is_main_process: bool = True,
1219
+ weight_name: str = None,
1220
+ save_function: Callable = None,
1221
+ safe_serialization: bool = True,
1222
+ ):
1223
+ r"""
1224
+ Save the LoRA parameters corresponding to the transformer and the text encoders.
1225
+
1226
+ Arguments:
1227
+ save_directory (`str` or `os.PathLike`):
1228
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
1229
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1230
+ State dict of the LoRA layers corresponding to the `transformer`.
1231
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1232
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
1233
+ encoder LoRA state dict because it comes from 🤗 Transformers.
1234
+ text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1235
+ State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
1236
+ encoder LoRA state dict because it comes from 🤗 Transformers.
1237
+ is_main_process (`bool`, *optional*, defaults to `True`):
1238
+ Whether the process calling this is the main process or not. Useful during distributed training when you
1239
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
1240
+ process to avoid race conditions.
1241
+ save_function (`Callable`):
1242
+ The function to use to save the state dictionary. Useful during distributed training when you need to
1243
+ replace `torch.save` with another method. Can be configured with the environment variable
1244
+ `DIFFUSERS_SAVE_MODE`.
1245
+ safe_serialization (`bool`, *optional*, defaults to `True`):
1246
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
1247
+ """
1248
+ state_dict = {}
1249
+
1250
+ if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
1251
+ raise ValueError(
1252
+ "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
1253
+ )
1254
+
1255
+ if transformer_lora_layers:
1256
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
1257
+
1258
+ if text_encoder_lora_layers:
1259
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
1260
+
1261
+ if text_encoder_2_lora_layers:
1262
+ state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
1263
+
1264
+ # Save the model
1265
+ cls.write_lora_layers(
1266
+ state_dict=state_dict,
1267
+ save_directory=save_directory,
1268
+ is_main_process=is_main_process,
1269
+ weight_name=weight_name,
1270
+ save_function=save_function,
1271
+ safe_serialization=safe_serialization,
1272
+ )
1273
+
1274
+ def fuse_lora(
1275
+ self,
1276
+ components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
1277
+ lora_scale: float = 1.0,
1278
+ safe_fusing: bool = False,
1279
+ adapter_names: Optional[List[str]] = None,
1280
+ **kwargs,
1281
+ ):
1282
+ r"""
1283
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
1284
+
1285
+ <Tip warning={true}>
1286
+
1287
+ This is an experimental API.
1288
+
1289
+ </Tip>
1290
+
1291
+ Args:
1292
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
1293
+ lora_scale (`float`, defaults to 1.0):
1294
+ Controls how much to influence the outputs with the LoRA parameters.
1295
+ safe_fusing (`bool`, defaults to `False`):
1296
+ Whether to check fused weights for NaN values before fusing and, if NaN values are present, to skip fusing them.
1297
+ adapter_names (`List[str]`, *optional*):
1298
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
1299
+
1300
+ Example:
1301
+
1302
+ ```py
1303
+ from diffusers import DiffusionPipeline
1304
+ import torch
1305
+
1306
+ pipeline = DiffusionPipeline.from_pretrained(
1307
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
1308
+ ).to("cuda")
1309
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
1310
+ pipeline.fuse_lora(lora_scale=0.7)
1311
+ ```
1312
+ """
1313
+ super().fuse_lora(
1314
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
1315
+ )
1316
+
1317
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
1318
+ r"""
1319
+ Reverses the effect of
1320
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
1321
+
1322
+ <Tip warning={true}>
1323
+
1324
+ This is an experimental API.
1325
+
1326
+ </Tip>
1327
+
1328
+ Args:
1329
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
1330
+ unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
1331
+ unfuse_text_encoder (`bool`, defaults to `True`):
1332
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
1333
+ LoRA parameters then it won't have any effect.
1334
+ """
1335
+ super().unfuse_lora(components=components)
1336
+
1337
+
1338
+ class FluxLoraLoaderMixin(LoraBaseMixin):
1339
+ r"""
1340
+ Load LoRA layers into [`FluxTransformer2DModel`] and
1341
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
1342
+
1343
+ Specific to [`FluxPipeline`].
1344
+ """
1345
+
1346
+ _lora_loadable_modules = ["transformer", "text_encoder"]
1347
+ transformer_name = TRANSFORMER_NAME
1348
+ text_encoder_name = TEXT_ENCODER_NAME
1349
+ _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
1350
+
1351
+ @classmethod
1352
+ @validate_hf_hub_args
1353
+ def lora_state_dict(
1354
+ cls,
1355
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
1356
+ return_alphas: bool = False,
1357
+ **kwargs,
1358
+ ):
1359
+ r"""
1360
+ Return state dict for lora weights and the network alphas.
1361
+
1362
+ <Tip warning={true}>
1363
+
1364
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
1365
+
1366
+ This function is experimental and might change in the future.
1367
+
1368
+ </Tip>
1369
+
1370
+ Parameters:
1371
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
1372
+ Can be either:
1373
+
1374
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1375
+ the Hub.
1376
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1377
+ with [`ModelMixin.save_pretrained`].
1378
+ - A [torch state
1379
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
1380
+
1381
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1382
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1383
+ is not used.
1384
+ force_download (`bool`, *optional*, defaults to `False`):
1385
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1386
+ cached versions if they exist.
1387
+
1388
+ proxies (`Dict[str, str]`, *optional*):
1389
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1390
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1391
+ local_files_only (`bool`, *optional*, defaults to `False`):
1392
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1393
+ won't be downloaded from the Hub.
1394
+ token (`str` or *bool*, *optional*):
1395
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1396
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1397
+ revision (`str`, *optional*, defaults to `"main"`):
1398
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1399
+ allowed by Git.
1400
+ subfolder (`str`, *optional*, defaults to `""`):
1401
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
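+ return_alphas (`bool`, *optional*, defaults to `False`):
+ Whether to also return the network alphas extracted from the state dict. When `True`, a
+ `(state_dict, network_alphas)` tuple is returned instead of the state dict alone.
+
+ Example (illustrative; the LoRA repository name is a placeholder):
+
+ ```py
+ from diffusers import FluxPipeline
+
+ state_dict, network_alphas = FluxPipeline.lora_state_dict(
+ "your-username/flux-lora", return_alphas=True
+ )
+ ```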
1402
+
1403
+ """
1404
+ # Load the main state dict first which has the LoRA layers for either of
1405
+ # transformer and text encoder or both.
1406
+ cache_dir = kwargs.pop("cache_dir", None)
1407
+ force_download = kwargs.pop("force_download", False)
1408
+ proxies = kwargs.pop("proxies", None)
1409
+ local_files_only = kwargs.pop("local_files_only", None)
1410
+ token = kwargs.pop("token", None)
1411
+ revision = kwargs.pop("revision", None)
1412
+ subfolder = kwargs.pop("subfolder", None)
1413
+ weight_name = kwargs.pop("weight_name", None)
1414
+ use_safetensors = kwargs.pop("use_safetensors", None)
1415
+
1416
+ allow_pickle = False
1417
+ if use_safetensors is None:
1418
+ use_safetensors = True
1419
+ allow_pickle = True
1420
+
1421
+ user_agent = {
1422
+ "file_type": "attn_procs_weights",
1423
+ "framework": "pytorch",
1424
+ }
1425
+
1426
+ state_dict = _fetch_state_dict(
1427
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
1428
+ weight_name=weight_name,
1429
+ use_safetensors=use_safetensors,
1430
+ local_files_only=local_files_only,
1431
+ cache_dir=cache_dir,
1432
+ force_download=force_download,
1433
+ proxies=proxies,
1434
+ token=token,
1435
+ revision=revision,
1436
+ subfolder=subfolder,
1437
+ user_agent=user_agent,
1438
+ allow_pickle=allow_pickle,
1439
+ )
1440
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
1441
+ if is_dora_scale_present:
1442
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
1443
+ logger.warning(warn_msg)
1444
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
1445
+
1446
+ # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
1447
+ is_kohya = any(".lora_down.weight" in k for k in state_dict)
1448
+ if is_kohya:
1449
+ state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
1450
+ # Kohya already takes care of scaling the LoRA parameters with alpha.
1451
+ return (state_dict, None) if return_alphas else state_dict
1452
+
1453
+ is_xlabs = any("processor" in k for k in state_dict)
1454
+ if is_xlabs:
1455
+ state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
1456
+ # xlabs doesn't use `alpha`.
1457
+ return (state_dict, None) if return_alphas else state_dict
1458
+
1459
+ is_bfl_control = any("query_norm.scale" in k for k in state_dict)
1460
+ if is_bfl_control:
1461
+ state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
1462
+ return (state_dict, None) if return_alphas else state_dict
1463
+
1464
+ # For state dicts like
1465
+ # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
1466
+ keys = list(state_dict.keys())
1467
+ network_alphas = {}
1468
+ for k in keys:
1469
+ if "alpha" in k:
1470
+ alpha_value = state_dict.get(k)
1471
+ if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
1472
+ alpha_value, float
1473
+ ):
1474
+ network_alphas[k] = state_dict.pop(k)
1475
+ else:
1476
+ raise ValueError(
1477
+ f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
1478
+ )
1479
+
1480
+ if return_alphas:
1481
+ return state_dict, network_alphas
1482
+ else:
1483
+ return state_dict
1484
+
1485
+ def load_lora_weights(
1486
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
1487
+ ):
1488
+ """
1489
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
1490
+ `self.text_encoder`.
1491
+
1492
+ All kwargs are forwarded to `self.lora_state_dict`.
1493
+
1494
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
1495
+ loaded.
1496
+
1497
+ See [`~loaders.FluxLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
1498
+ dict is loaded into `self.transformer`.
1499
+
1500
+ Parameters:
1501
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
1502
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
1503
+ kwargs (`dict`, *optional*):
1504
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
1505
+ adapter_name (`str`, *optional*):
1506
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1507
+ `default_{i}` where i is the total number of adapters being loaded.
1508
+ low_cpu_mem_usage (`bool`, *optional*):
1509
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1510
+ weights.
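+
+ Example (a minimal sketch; the base model id is real, the LoRA repository is a placeholder):
+
+ ```py
+ from diffusers import FluxPipeline
+ import torch
+
+ pipeline = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+ ).to("cuda")
+ pipeline.load_lora_weights("your-username/flux-lora", adapter_name="example")
+ ```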
1511
+ """
1512
+ if not USE_PEFT_BACKEND:
1513
+ raise ValueError("PEFT backend is required for this method.")
1514
+
1515
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
1516
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
1517
+ raise ValueError(
1518
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
1519
+ )
1520
+
1521
+ # if a dict is passed, copy it instead of modifying it inplace
1522
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
1523
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
1524
+
1525
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
1526
+ state_dict, network_alphas = self.lora_state_dict(
1527
+ pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
1528
+ )
1529
+
1530
+ has_lora_keys = any("lora" in key for key in state_dict.keys())
1531
+
1532
+ # Flux Control LoRAs also have norm keys
1533
+ has_norm_keys = any(
1534
+ norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys
1535
+ )
1536
+
1537
+ if not (has_lora_keys or has_norm_keys):
1538
+ raise ValueError("Invalid LoRA checkpoint.")
1539
+
1540
+ transformer_lora_state_dict = {
1541
+ k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
1542
+ }
1543
+ transformer_norm_state_dict = {
1544
+ k: state_dict.pop(k)
1545
+ for k in list(state_dict.keys())
1546
+ if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
1547
+ }
1548
+
1549
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
1550
+ has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
1551
+ transformer, transformer_lora_state_dict, transformer_norm_state_dict
1552
+ )
1553
+
1554
+ if has_param_with_expanded_shape:
1555
+ logger.info(
1556
+ "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
1557
+ "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
1558
+ "To get a comprehensive list of parameter names that were modified, enable debug logging."
1559
+ )
1560
+ transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
1561
+ transformer=transformer, lora_state_dict=transformer_lora_state_dict
1562
+ )
1563
+
1564
+ if len(transformer_lora_state_dict) > 0:
1565
+ self.load_lora_into_transformer(
1566
+ transformer_lora_state_dict,
1567
+ network_alphas=network_alphas,
1568
+ transformer=transformer,
1569
+ adapter_name=adapter_name,
1570
+ _pipeline=self,
1571
+ low_cpu_mem_usage=low_cpu_mem_usage,
1572
+ )
1573
+
1574
+ if len(transformer_norm_state_dict) > 0:
1575
+ transformer._transformer_norm_layers = self._load_norm_into_transformer(
1576
+ transformer_norm_state_dict,
1577
+ transformer=transformer,
1578
+ discard_original_layers=False,
1579
+ )
1580
+
1581
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
1582
+ if len(text_encoder_state_dict) > 0:
1583
+ self.load_lora_into_text_encoder(
1584
+ text_encoder_state_dict,
1585
+ network_alphas=network_alphas,
1586
+ text_encoder=self.text_encoder,
1587
+ prefix="text_encoder",
1588
+ lora_scale=self.lora_scale,
1589
+ adapter_name=adapter_name,
1590
+ _pipeline=self,
1591
+ low_cpu_mem_usage=low_cpu_mem_usage,
1592
+ )
1593
+
1594
+ @classmethod
1595
+ def load_lora_into_transformer(
1596
+ cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
1597
+ ):
1598
+ """
1599
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
1600
+
1601
+ Parameters:
1602
+ state_dict (`dict`):
1603
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
1604
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish them from the text
1605
+ encoder LoRA layers.
1606
+ network_alphas (`Dict[str, float]`):
1607
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
1608
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
1609
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
1610
+ transformer (`FluxTransformer2DModel`):
1611
+ The Transformer model to load the LoRA layers into.
1612
+ adapter_name (`str`, *optional*):
1613
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1614
+ `default_{i}` where i is the total number of adapters being loaded.
1615
+ low_cpu_mem_usage (`bool`, *optional*):
1616
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1617
+ weights.
1618
+ """
1619
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
1620
+ raise ValueError(
1621
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
1622
+ )
1623
+
1624
+ # Load the layers corresponding to transformer.
1625
+ keys = list(state_dict.keys())
1626
+ transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
1627
+ if transformer_present:
1628
+ logger.info(f"Loading {cls.transformer_name}.")
1629
+ transformer.load_lora_adapter(
1630
+ state_dict,
1631
+ network_alphas=network_alphas,
1632
+ adapter_name=adapter_name,
1633
+ _pipeline=_pipeline,
1634
+ low_cpu_mem_usage=low_cpu_mem_usage,
1635
+ )
1636
+
1637
+ @classmethod
1638
+ def _load_norm_into_transformer(
1639
+ cls,
1640
+ state_dict,
1641
+ transformer,
1642
+ prefix=None,
1643
+ discard_original_layers=False,
1644
+ ) -> Dict[str, torch.Tensor]:
1645
+ # Remove prefix if present
1646
+ prefix = prefix or cls.transformer_name
1647
+ for key in list(state_dict.keys()):
1648
+ if key.split(".")[0] == prefix:
1649
+ state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
1650
+
1651
+ # Find invalid keys
1652
+ transformer_state_dict = transformer.state_dict()
1653
+ transformer_keys = set(transformer_state_dict.keys())
1654
+ state_dict_keys = set(state_dict.keys())
1655
+ extra_keys = list(state_dict_keys - transformer_keys)
1656
+
1657
+ if extra_keys:
1658
+ logger.warning(
1659
+ f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
1660
+ )
1661
+
1662
+ for key in extra_keys:
1663
+ state_dict.pop(key)
1664
+
1665
+ # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected
1666
+ overwritten_layers_state_dict = {}
1667
+ if not discard_original_layers:
1668
+ for key in state_dict.keys():
1669
+ overwritten_layers_state_dict[key] = transformer_state_dict[key].clone()
1670
+
1671
+ logger.info(
1672
+ "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
1673
+ 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly '
1674
+ "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. "
1675
+ "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues."
1676
+ )
1677
+
1678
+ # We can't load with strict=True because the current state_dict does not contain all the transformer keys
1679
+ incompatible_keys = transformer.load_state_dict(state_dict, strict=False)
1680
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
1681
+
1682
+ # We shouldn't expect to see the supported norm keys here being present in the unexpected keys.
1683
+ if unexpected_keys:
1684
+ if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys):
1685
+ raise ValueError(
1686
+ f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer."
1687
+ )
1688
+
1689
+ return overwritten_layers_state_dict
1690
+
1691
+ @classmethod
1692
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
1693
+ def load_lora_into_text_encoder(
1694
+ cls,
1695
+ state_dict,
1696
+ network_alphas,
1697
+ text_encoder,
1698
+ prefix=None,
1699
+ lora_scale=1.0,
1700
+ adapter_name=None,
1701
+ _pipeline=None,
1702
+ low_cpu_mem_usage=False,
1703
+ ):
1704
+ """
1705
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
1706
+
1707
+ Parameters:
1708
+ state_dict (`dict`):
1709
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
1710
+ additional `text_encoder` to distinguish them from the UNet LoRA layers.
1711
+ network_alphas (`Dict[str, float]`):
1712
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
1713
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
1714
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
1715
+ text_encoder (`CLIPTextModel`):
1716
+ The text encoder model to load the LoRA layers into.
1717
+ prefix (`str`):
1718
+ Expected prefix of the `text_encoder` in the `state_dict`.
1719
+ lora_scale (`float`):
1720
+ How much to scale the output of the LoRA linear layer before it is added to the output of the regular
1721
+ layer.
1722
+ adapter_name (`str`, *optional*):
1723
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1724
+ `default_{i}` where i is the total number of adapters being loaded.
1725
+ low_cpu_mem_usage (`bool`, *optional*):
1726
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1727
+ weights.
1728
+ """
1729
+ _load_lora_into_text_encoder(
1730
+ state_dict=state_dict,
1731
+ network_alphas=network_alphas,
1732
+ lora_scale=lora_scale,
1733
+ text_encoder=text_encoder,
1734
+ prefix=prefix,
1735
+ text_encoder_name=cls.text_encoder_name,
1736
+ adapter_name=adapter_name,
1737
+ _pipeline=_pipeline,
1738
+ low_cpu_mem_usage=low_cpu_mem_usage,
1739
+ )
1740
+
1741
+ @classmethod
1742
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
1743
+ def save_lora_weights(
1744
+ cls,
1745
+ save_directory: Union[str, os.PathLike],
1746
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1747
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
1748
+ is_main_process: bool = True,
1749
+ weight_name: str = None,
1750
+ save_function: Callable = None,
1751
+ safe_serialization: bool = True,
1752
+ ):
1753
+ r"""
1754
+ Save the LoRA parameters corresponding to the transformer and text encoder.
1755
+
1756
+ Arguments:
1757
+ save_directory (`str` or `os.PathLike`):
1758
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
1759
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1760
+ State dict of the LoRA layers corresponding to the `transformer`.
1761
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1762
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
1763
+ encoder LoRA state dict because it comes from 🤗 Transformers.
1764
+ is_main_process (`bool`, *optional*, defaults to `True`):
1765
+ Whether the process calling this is the main process or not. Useful during distributed training when you
1766
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
1767
+ process to avoid race conditions.
1768
+ save_function (`Callable`):
1769
+ The function to use to save the state dictionary. Useful during distributed training when you need to
1770
+ replace `torch.save` with another method. Can be configured with the environment variable
1771
+ `DIFFUSERS_SAVE_MODE`.
1772
+ safe_serialization (`bool`, *optional*, defaults to `True`):
1773
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
1774
+ """
1775
+ state_dict = {}
1776
+
1777
+ if not (transformer_lora_layers or text_encoder_lora_layers):
1778
+ raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
1779
+
1780
+ if transformer_lora_layers:
1781
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
1782
+
1783
+ if text_encoder_lora_layers:
1784
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
1785
+
1786
+ # Save the model
1787
+ cls.write_lora_layers(
1788
+ state_dict=state_dict,
1789
+ save_directory=save_directory,
1790
+ is_main_process=is_main_process,
1791
+ weight_name=weight_name,
1792
+ save_function=save_function,
1793
+ safe_serialization=safe_serialization,
1794
+ )
1795
+
1796
+ def fuse_lora(
1797
+ self,
1798
+ components: List[str] = ["transformer"],
1799
+ lora_scale: float = 1.0,
1800
+ safe_fusing: bool = False,
1801
+ adapter_names: Optional[List[str]] = None,
1802
+ **kwargs,
1803
+ ):
1804
+ r"""
1805
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
1806
+
1807
+ <Tip warning={true}>
1808
+
1809
+ This is an experimental API.
1810
+
1811
+ </Tip>
1812
+
1813
+ Args:
1814
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
1815
+ lora_scale (`float`, defaults to 1.0):
1816
+ Controls how much to influence the outputs with the LoRA parameters.
1817
+ safe_fusing (`bool`, defaults to `False`):
1818
+ Whether to check fused weights for NaN values before fusing and, if NaN values are present, to skip fusing them.
1819
+ adapter_names (`List[str]`, *optional*):
1820
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
1821
+
1822
+ Example:
1823
+
1824
+ ```py
1825
+ from diffusers import DiffusionPipeline
1826
+ import torch
1827
+
1828
+ pipeline = DiffusionPipeline.from_pretrained(
1829
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
1830
+ ).to("cuda")
1831
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
1832
+ pipeline.fuse_lora(lora_scale=0.7)
1833
+ ```
1834
+ """
1835
+
1836
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
1837
+ if (
1838
+ hasattr(transformer, "_transformer_norm_layers")
1839
+ and isinstance(transformer._transformer_norm_layers, dict)
1840
+ and len(transformer._transformer_norm_layers.keys()) > 0
1841
+ ):
1842
+ logger.info(
1843
+ "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer "
1844
+ "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly "
1845
+ "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
1846
+ )
1847
+
1848
+ super().fuse_lora(
1849
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
1850
+ )
1851
+
1852
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
1853
+ r"""
1854
+ Reverses the effect of
1855
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
1856
+
1857
+ <Tip warning={true}>
1858
+
1859
+ This is an experimental API.
1860
+
1861
+ </Tip>
1862
+
1863
+ Args:
1864
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
1865
+ """
1866
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
1867
+ if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
1868
+ transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
1869
+
1870
+ super().unfuse_lora(components=components)
1871
+
1872
+ # We override this here to account for `_transformer_norm_layers` and `_overwritten_params`.
1873
+ def unload_lora_weights(self, reset_to_overwritten_params=False):
1874
+ """
1875
+ Unloads the LoRA parameters.
1876
+
1877
+ Args:
1878
+ reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules
1879
+ to their original params. Refer to the [Flux
1880
+ documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
1881
+
1882
+ Examples:
1883
+
1884
+ ```python
1885
+ >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
1886
+ >>> pipeline.unload_lora_weights()
1887
+ >>> ...
1888
+ ```
1889
+ """
1890
+ super().unload_lora_weights()
1891
+
1892
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
1893
+ if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
1894
+ transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
1895
+ transformer._transformer_norm_layers = None
1896
+
1897
+ if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
1898
+ overwritten_params = transformer._overwritten_params
1899
+ module_names = set()
1900
+
1901
+ for param_name in overwritten_params:
1902
+ if param_name.endswith(".weight"):
1903
+ module_names.add(param_name.replace(".weight", ""))
1904
+
1905
+ for name, module in transformer.named_modules():
1906
+ if isinstance(module, torch.nn.Linear) and name in module_names:
1907
+ module_weight = module.weight.data
1908
+ module_bias = module.bias.data if module.bias is not None else None
1909
+ bias = module_bias is not None
1910
+
1911
+ parent_module_name, _, current_module_name = name.rpartition(".")
1912
+ parent_module = transformer.get_submodule(parent_module_name)
1913
+
1914
+ current_param_weight = overwritten_params[f"{name}.weight"]
1915
+ in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
1916
+ with torch.device("meta"):
1917
+ original_module = torch.nn.Linear(
1918
+ in_features,
1919
+ out_features,
1920
+ bias=bias,
1921
+ dtype=module_weight.dtype,
1922
+ )
1923
+
1924
+ tmp_state_dict = {"weight": current_param_weight}
1925
+ if module_bias is not None:
1926
+ tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
1927
+ original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
1928
+ setattr(parent_module, current_module_name, original_module)
1929
+
1930
+ del tmp_state_dict
1931
+
1932
+ if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
1933
+ attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
1934
+ new_value = int(current_param_weight.shape[1])
1935
+ old_value = getattr(transformer.config, attribute_name)
1936
+ setattr(transformer.config, attribute_name, new_value)
1937
+ logger.info(
1938
+ f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
1939
+ )
1940
+
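For shape-expanding LoRAs (e.g. Flux Control LoRAs), a plain `unload_lora_weights()` keeps the expanded `Linear` modules in place; passing `reset_to_overwritten_params=True` swaps the cached originals back in and restores the config attributes, as the loop above shows. A minimal sketch (the LoRA repo id is a hypothetical placeholder):

```py
# `pipeline` is assumed to be a Flux pipeline whose transformer gets expanded by a Control LoRA.
pipeline.load_lora_weights("some-user/flux-control-lora")  # hypothetical repo id

# Option 1: drop the LoRA but keep the expanded input layers (shapes stay enlarged).
pipeline.unload_lora_weights()

# Option 2: also shrink the expanded layers back to their original, pre-LoRA shapes.
pipeline.unload_lora_weights(reset_to_overwritten_params=True)
```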
+     @classmethod
+     def _maybe_expand_transformer_param_shape_or_error_(
+         cls,
+         transformer: torch.nn.Module,
+         lora_state_dict=None,
+         norm_state_dict=None,
+         prefix=None,
+     ) -> bool:
+         """
+         Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
+         generalizes things a bit so that any parameter that needs expansion receives appropriate treatment.
+         """
+         state_dict = {}
+         if lora_state_dict is not None:
+             state_dict.update(lora_state_dict)
+         if norm_state_dict is not None:
+             state_dict.update(norm_state_dict)
+
+         # Remove prefix if present
+         prefix = prefix or cls.transformer_name
+         for key in list(state_dict.keys()):
+             if key.split(".")[0] == prefix:
+                 state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+
+         # Expand transformer parameter shapes if they don't match lora
+         has_param_with_shape_update = False
+         overwritten_params = {}
+
+         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+         for name, module in transformer.named_modules():
+             if isinstance(module, torch.nn.Linear):
+                 module_weight = module.weight.data
+                 module_bias = module.bias.data if module.bias is not None else None
+                 bias = module_bias is not None
+
+                 lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
+                 lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
+                 lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
+                 if lora_A_weight_name not in state_dict:
+                     continue
+
+                 in_features = state_dict[lora_A_weight_name].shape[1]
+                 out_features = state_dict[lora_B_weight_name].shape[0]
+
+                 # The model may be loaded with different quantization schemes, which may flatten the params.
+                 # `bitsandbytes`, for example, flattens the weights when using 4-bit. 8-bit bnb models
+                 # preserve the weight shape.
+                 module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
+
+                 # This means there's no need for an expansion in the params, so we simply skip.
+                 if tuple(module_weight_shape) == (out_features, in_features):
+                     continue
+
+                 # TODO (sayakpaul): We still need to consider if the module we're expanding is
+                 # quantized and handle it accordingly if that is the case.
+                 module_out_features, module_in_features = module_weight.shape
+                 debug_message = ""
+                 if in_features > module_in_features:
+                     debug_message += (
+                         f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
+                         f"checkpoint contains a higher number of features than expected. The number of input_features will be "
+                         f"expanded from {module_in_features} to {in_features}"
+                     )
+                 if out_features > module_out_features:
+                     debug_message += (
+                         ", and the number of output features will be "
+                         f"expanded from {module_out_features} to {out_features}."
+                     )
+                 else:
+                     debug_message += "."
+                 if debug_message:
+                     logger.debug(debug_message)
+
+                 if out_features > module_out_features or in_features > module_in_features:
+                     has_param_with_shape_update = True
+                     parent_module_name, _, current_module_name = name.rpartition(".")
+                     parent_module = transformer.get_submodule(parent_module_name)
+
+                     with torch.device("meta"):
+                         expanded_module = torch.nn.Linear(
+                             in_features, out_features, bias=bias, dtype=module_weight.dtype
+                         )
+                     # Only weights are expanded and biases are not. This is because only the input dimensions
+                     # are changed while the output dimensions remain the same. The shape of the weight tensor
+                     # is (out_features, in_features), while the shape of the bias tensor is (out_features,), which
+                     # explains why only the weights are expanded.
+                     new_weight = torch.zeros_like(
+                         expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+                     )
+                     slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                     new_weight[slices] = module_weight
+                     tmp_state_dict = {"weight": new_weight}
+                     if module_bias is not None:
+                         tmp_state_dict["bias"] = module_bias
+                     expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
+
+                     setattr(parent_module, current_module_name, expanded_module)
+
+                     del tmp_state_dict
+
+                     if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+                         attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+                         new_value = int(expanded_module.weight.data.shape[1])
+                         old_value = getattr(transformer.config, attribute_name)
+                         setattr(transformer.config, attribute_name, new_value)
+                         logger.info(
+                             f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+                         )
+
+                     # For `unload_lora_weights()`.
+                     # TODO: this could lead to more memory overhead if the number of overwritten params
+                     # is large. Should be revisited later and tackled through a `discard_original_layers` arg.
+                     overwritten_params[f"{current_module_name}.weight"] = module_weight
+                     if module_bias is not None:
+                         overwritten_params[f"{current_module_name}.bias"] = module_bias
+
+         if len(overwritten_params) > 0:
+             transformer._overwritten_params = overwritten_params
+
+         return has_param_with_shape_update
+
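The expansion itself is plain zero-padding: the old weight occupies the leading slice of the larger weight, so outputs are unchanged whenever the extra input channels are zero. A self-contained sketch of the core step, with shapes chosen to mirror the (3072, 64) -> (3072, 128) example from the docstring:

```py
import torch

old = torch.nn.Linear(64, 3072, bias=True)

# Expand in_features from 64 to 128; copy the old weight into the top-left slice.
expanded = torch.nn.Linear(128, 3072, bias=True, dtype=old.weight.dtype)
new_weight = torch.zeros_like(expanded.weight.data)
slices = tuple(slice(0, dim) for dim in old.weight.shape)
new_weight[slices] = old.weight.data
expanded.load_state_dict({"weight": new_weight, "bias": old.bias.data}, strict=True)

# With the extra 64 input channels zeroed, both layers agree.
x = torch.randn(2, 64)
x_padded = torch.cat([x, torch.zeros(2, 64)], dim=-1)
assert torch.allclose(old(x), expanded(x_padded), atol=1e-6)
```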
+     @classmethod
+     def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
+         expanded_module_names = set()
+         transformer_state_dict = transformer.state_dict()
+         prefix = f"{cls.transformer_name}."
+
+         lora_module_names = [
+             key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
+         ]
+         lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
+         lora_module_names = sorted(set(lora_module_names))
+         transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
+         unexpected_modules = set(lora_module_names) - set(transformer_module_names)
+         if unexpected_modules:
+             logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
+
+         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+         for k in lora_module_names:
+             if k in unexpected_modules:
+                 continue
+
+             base_param_name = (
+                 f"{k.replace(prefix, '')}.base_layer.weight"
+                 if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
+                 else f"{k.replace(prefix, '')}.weight"
+             )
+             base_weight_param = transformer_state_dict[base_param_name]
+             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
+
+             # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
+             base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
+
+             if base_module_shape[1] > lora_A_param.shape[1]:
+                 shape = (lora_A_param.shape[0], base_weight_param.shape[1])
+                 expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
+                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
+                 lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
+                 expanded_module_names.add(k)
+             elif base_module_shape[1] < lora_A_param.shape[1]:
+                 raise NotImplementedError(
+                     f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
+                 )
+
+         if expanded_module_names:
+             logger.info(
+                 f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
+             )
+
+         return lora_state_dict
+
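The dual case pads the checkpoint instead of the model: when the base `Linear` is wider than the checkpoint's `lora_A`, the `lora_A` weight is zero-padded along its input dimension, which leaves the adapter's contribution on the original channels untouched. A standalone sketch:

```py
import torch

rank, lora_in, base_in = 16, 64, 128
lora_A = torch.randn(rank, lora_in)

# Zero-pad lora_A from (16, 64) to (16, 128) to match the wider base layer.
expanded_A = torch.zeros(rank, base_in)
expanded_A[:, :lora_in].copy_(lora_A)

# The padded adapter simply ignores the extra input channels.
x = torch.randn(2, base_in)
assert torch.allclose(x[:, :lora_in] @ lora_A.T, x @ expanded_A.T)
```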
+     @staticmethod
+     def _calculate_module_shape(
+         model: "torch.nn.Module",
+         base_module: "torch.nn.Linear" = None,
+         base_weight_param_name: str = None,
+     ) -> "torch.Size":
+         def _get_weight_shape(weight: torch.Tensor):
+             return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+
+         if base_module is not None:
+             return _get_weight_shape(base_module.weight)
+         elif base_weight_param_name is not None:
+             if not base_weight_param_name.endswith(".weight"):
+                 raise ValueError(
+                     f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
+                 )
+             module_path = base_weight_param_name.rsplit(".weight", 1)[0]
+             submodule = get_submodule_by_name(model, module_path)
+             return _get_weight_shape(submodule.weight)
+
+         raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
+
+
+ # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
+ # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
+ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
+     _lora_loadable_modules = ["transformer", "text_encoder"]
+     transformer_name = TRANSFORMER_NAME
+     text_encoder_name = TEXT_ENCODER_NAME
+
+     @classmethod
+     # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
+     def load_lora_into_transformer(
+         cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+     ):
+         """
+         This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+         Parameters:
+             state_dict (`dict`):
+                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                 encoder lora layers.
+             network_alphas (`Dict[str, float]`):
+                 The value of the network alpha used for stable learning and preventing underflow. This value has the
+                 same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                 link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+             transformer (`UVit2DModel`):
+                 The Transformer model to load the LoRA layers into.
+             adapter_name (`str`, *optional*):
+                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                 `default_{i}` where i is the total number of adapters being loaded.
+             low_cpu_mem_usage (`bool`, *optional*):
+                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                 weights.
+         """
+         if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+             raise ValueError(
+                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+             )
+
+         # Load the layers corresponding to transformer.
+         keys = list(state_dict.keys())
+         transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
+         if transformer_present:
+             logger.info(f"Loading {cls.transformer_name}.")
+             transformer.load_lora_adapter(
+                 state_dict,
+                 network_alphas=network_alphas,
+                 adapter_name=adapter_name,
+                 _pipeline=_pipeline,
+                 low_cpu_mem_usage=low_cpu_mem_usage,
+             )
+
+     @classmethod
+     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
+     def load_lora_into_text_encoder(
+         cls,
+         state_dict,
+         network_alphas,
+         text_encoder,
+         prefix=None,
+         lora_scale=1.0,
+         adapter_name=None,
+         _pipeline=None,
+         low_cpu_mem_usage=False,
+     ):
+         """
+         This will load the LoRA layers specified in `state_dict` into `text_encoder`.
+
+         Parameters:
+             state_dict (`dict`):
+                 A standard state dict containing the lora layer parameters. The key should be prefixed with an
+                 additional `text_encoder` to distinguish between unet lora layers.
+             network_alphas (`Dict[str, float]`):
+                 The value of the network alpha used for stable learning and preventing underflow. This value has the
+                 same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                 link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+             text_encoder (`CLIPTextModel`):
+                 The text encoder model to load the LoRA layers into.
+             prefix (`str`):
+                 Expected prefix of the `text_encoder` in the `state_dict`.
+             lora_scale (`float`):
+                 How much to scale the output of the lora linear layer before it is added to the output of the regular
+                 layer.
+             adapter_name (`str`, *optional*):
+                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                 `default_{i}` where i is the total number of adapters being loaded.
+             low_cpu_mem_usage (`bool`, *optional*):
+                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                 weights.
+         """
+         _load_lora_into_text_encoder(
+             state_dict=state_dict,
+             network_alphas=network_alphas,
+             lora_scale=lora_scale,
+             text_encoder=text_encoder,
+             prefix=prefix,
+             text_encoder_name=cls.text_encoder_name,
+             adapter_name=adapter_name,
+             _pipeline=_pipeline,
+             low_cpu_mem_usage=low_cpu_mem_usage,
+         )
+
+     @classmethod
+     def save_lora_weights(
+         cls,
+         save_directory: Union[str, os.PathLike],
+         text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+         transformer_lora_layers: Dict[str, torch.nn.Module] = None,
+         is_main_process: bool = True,
+         weight_name: str = None,
+         save_function: Callable = None,
+         safe_serialization: bool = True,
+     ):
+         r"""
+         Save the LoRA parameters corresponding to the transformer and text encoder.
+
+         Arguments:
+             save_directory (`str` or `os.PathLike`):
+                 Directory to save LoRA parameters to. Will be created if it doesn't exist.
+             transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                 State dict of the LoRA layers corresponding to the `transformer`.
+             text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                 State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+                 encoder LoRA state dict because it comes from 🤗 Transformers.
+             is_main_process (`bool`, *optional*, defaults to `True`):
+                 Whether the process calling this is the main process or not. Useful during distributed training when
+                 you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                 main process to avoid race conditions.
+             save_function (`Callable`):
+                 The function to use to save the state dictionary. Useful during distributed training when you need to
+                 replace `torch.save` with another method. Can be configured with the environment variable
+                 `DIFFUSERS_SAVE_MODE`.
+             safe_serialization (`bool`, *optional*, defaults to `True`):
+                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+         """
+         state_dict = {}
+
+         if not (transformer_lora_layers or text_encoder_lora_layers):
+             raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+         if transformer_lora_layers:
+             state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+         if text_encoder_lora_layers:
+             state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+
+         # Save the model
+         cls.write_lora_layers(
+             state_dict=state_dict,
+             save_directory=save_directory,
+             is_main_process=is_main_process,
+             weight_name=weight_name,
+             save_function=save_function,
+             safe_serialization=safe_serialization,
+         )
+
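A sketch of the save path, assuming the adapters were trained with `peft` (so `get_peft_model_state_dict` can extract them) and that `pipe` is an Amused pipeline that mixes in `AmusedLoraLoaderMixin` with trained adapters on both components; the directory and file names are placeholders:

```py
from peft.utils import get_peft_model_state_dict

transformer_lora = get_peft_model_state_dict(pipe.transformer)
text_encoder_lora = get_peft_model_state_dict(pipe.text_encoder)

pipe.save_lora_weights(
    save_directory="./amused-lora",
    transformer_lora_layers=transformer_lora,
    text_encoder_lora_layers=text_encoder_lora,
    weight_name="pytorch_lora_weights.safetensors",
)
```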
+
+
+ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
+     r"""
+     Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
+     """
+
+     _lora_loadable_modules = ["transformer"]
+     transformer_name = TRANSFORMER_NAME
+
+     @classmethod
+     @validate_hf_hub_args
+     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+     def lora_state_dict(
+         cls,
+         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+         **kwargs,
+     ):
+         r"""
+         Return state dict for lora weights and the network alphas.
+
+         <Tip warning={true}>
+
+         We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+         This function is experimental and might change in the future.
+
+         </Tip>
+
+         Parameters:
+             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                 Can be either:
+
+                     - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                       the Hub.
+                     - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                       with [`ModelMixin.save_pretrained`].
+                     - A [torch state
+                       dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+             cache_dir (`Union[str, os.PathLike]`, *optional*):
+                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                 is not used.
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                 cached versions if they exist.
+             proxies (`Dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+             local_files_only (`bool`, *optional*, defaults to `False`):
+                 Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                 won't be downloaded from the Hub.
+             token (`str` or *bool*, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
+             revision (`str`, *optional*, defaults to `"main"`):
+                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                 allowed by Git.
+             subfolder (`str`, *optional*, defaults to `""`):
+                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+         """
+         # Load the main state dict first, which has the LoRA layers for either the
+         # transformer or the text encoder, or both.
+         cache_dir = kwargs.pop("cache_dir", None)
+         force_download = kwargs.pop("force_download", False)
+         proxies = kwargs.pop("proxies", None)
+         local_files_only = kwargs.pop("local_files_only", None)
+         token = kwargs.pop("token", None)
+         revision = kwargs.pop("revision", None)
+         subfolder = kwargs.pop("subfolder", None)
+         weight_name = kwargs.pop("weight_name", None)
+         use_safetensors = kwargs.pop("use_safetensors", None)
+
+         allow_pickle = False
+         if use_safetensors is None:
+             use_safetensors = True
+             allow_pickle = True
+
+         user_agent = {
+             "file_type": "attn_procs_weights",
+             "framework": "pytorch",
+         }
+
+         state_dict = _fetch_state_dict(
+             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+             weight_name=weight_name,
+             use_safetensors=use_safetensors,
+             local_files_only=local_files_only,
+             cache_dir=cache_dir,
+             force_download=force_download,
+             proxies=proxies,
+             token=token,
+             revision=revision,
+             subfolder=subfolder,
+             user_agent=user_agent,
+             allow_pickle=allow_pickle,
+         )
+
+         is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+         if is_dora_scale_present:
+             warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+             logger.warning(warn_msg)
+             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+         return state_dict
+
+     def load_lora_weights(
+         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+     ):
+         """
+         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All kwargs
+         are forwarded to `self.lora_state_dict`. See
+         [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+         See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+         dict is loaded into `self.transformer`.
+
+         Parameters:
+             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+             adapter_name (`str`, *optional*):
+                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                 `default_{i}` where i is the total number of adapters being loaded.
+             low_cpu_mem_usage (`bool`, *optional*):
+                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                 weights.
+             kwargs (`dict`, *optional*):
+                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+         """
+         if not USE_PEFT_BACKEND:
+             raise ValueError("PEFT backend is required for this method.")
+
+         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+             raise ValueError(
+                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+             )
+
+         # if a dict is passed, copy it instead of modifying it inplace
+         if isinstance(pretrained_model_name_or_path_or_dict, dict):
+             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+         is_correct_format = all("lora" in key for key in state_dict.keys())
+         if not is_correct_format:
+             raise ValueError("Invalid LoRA checkpoint.")
+
+         self.load_lora_into_transformer(
+             state_dict,
+             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+             adapter_name=adapter_name,
+             _pipeline=self,
+             low_cpu_mem_usage=low_cpu_mem_usage,
+         )
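End to end, loading a transformer-only LoRA into a CogVideoX pipeline looks like the following sketch; the base checkpoint is the public `THUDM/CogVideoX-5b`, while the LoRA repo id and weight name are hypothetical placeholders:

```py
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights(
    "some-user/cogvideox-lora",  # hypothetical repo id
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="motion",
)
video = pipe("a koala surfing a small wave", num_frames=49).frames[0]
```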
+
+     @classmethod
+     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
+     def load_lora_into_transformer(
+         cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+     ):
+         """
+         This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+         Parameters:
+             state_dict (`dict`):
+                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                 encoder lora layers.
+             transformer (`CogVideoXTransformer3DModel`):
+                 The Transformer model to load the LoRA layers into.
+             adapter_name (`str`, *optional*):
+                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                 `default_{i}` where i is the total number of adapters being loaded.
+             low_cpu_mem_usage (`bool`, *optional*):
+                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                 weights.
+         """
+         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+             raise ValueError(
+                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+             )
+
+         # Load the layers corresponding to transformer.
+         logger.info(f"Loading {cls.transformer_name}.")
+         transformer.load_lora_adapter(
+             state_dict,
+             network_alphas=None,
+             adapter_name=adapter_name,
+             _pipeline=_pipeline,
+             low_cpu_mem_usage=low_cpu_mem_usage,
+         )
+
+     @classmethod
+     # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
+     def save_lora_weights(
+         cls,
+         save_directory: Union[str, os.PathLike],
+         transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+         is_main_process: bool = True,
+         weight_name: str = None,
+         save_function: Callable = None,
+         safe_serialization: bool = True,
+     ):
+         r"""
+         Save the LoRA parameters corresponding to the transformer.
+
+         Arguments:
+             save_directory (`str` or `os.PathLike`):
+                 Directory to save LoRA parameters to. Will be created if it doesn't exist.
+             transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                 State dict of the LoRA layers corresponding to the `transformer`.
+             is_main_process (`bool`, *optional*, defaults to `True`):
+                 Whether the process calling this is the main process or not. Useful during distributed training when
+                 you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                 main process to avoid race conditions.
+             save_function (`Callable`):
+                 The function to use to save the state dictionary. Useful during distributed training when you need to
+                 replace `torch.save` with another method. Can be configured with the environment variable
+                 `DIFFUSERS_SAVE_MODE`.
+             safe_serialization (`bool`, *optional*, defaults to `True`):
+                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+         """
+         state_dict = {}
+
+         if not transformer_lora_layers:
+             raise ValueError("You must pass `transformer_lora_layers`.")
+
+         if transformer_lora_layers:
+             state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+         # Save the model
+         cls.write_lora_layers(
+             state_dict=state_dict,
+             save_directory=save_directory,
+             is_main_process=is_main_process,
+             weight_name=weight_name,
+             save_function=save_function,
+             safe_serialization=safe_serialization,
+         )
+
+     def fuse_lora(
+         self,
+         components: List[str] = ["transformer"],
+         lora_scale: float = 1.0,
+         safe_fusing: bool = False,
+         adapter_names: Optional[List[str]] = None,
+         **kwargs,
+     ):
+         r"""
+         Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+         <Tip warning={true}>
+
+         This is an experimental API.
+
+         </Tip>
+
+         Args:
+             components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+             lora_scale (`float`, defaults to 1.0):
+                 Controls how much to influence the outputs with the LoRA parameters.
+             safe_fusing (`bool`, defaults to `False`):
+                 Whether to check fused weights for NaN values before fusing, and to skip fusing if NaN values are
+                 present.
+             adapter_names (`List[str]`, *optional*):
+                 Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+         Example:
+
+         ```py
+         from diffusers import DiffusionPipeline
+         import torch
+
+         pipeline = DiffusionPipeline.from_pretrained(
+             "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+         ).to("cuda")
+         pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+         pipeline.fuse_lora(lora_scale=0.7)
+         ```
+         """
+         super().fuse_lora(
+             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+         )
+
+     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+         r"""
+         Reverses the effect of
+         [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+         <Tip warning={true}>
+
+         This is an experimental API.
+
+         </Tip>
+
+         Args:
+             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+         """
+         super().unfuse_lora(components=components)
+
+
+ class Mochi1LoraLoaderMixin(LoraBaseMixin):
+     r"""
+     Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
+     """
+
+     _lora_loadable_modules = ["transformer"]
+     transformer_name = TRANSFORMER_NAME
+
+     @classmethod
+     @validate_hf_hub_args
+     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+     def lora_state_dict(
+         cls,
+         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+         **kwargs,
+     ):
+         r"""
+         Return state dict for lora weights and the network alphas.
+
+         <Tip warning={true}>
+
+         We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+         This function is experimental and might change in the future.
+
+         </Tip>
+
+         Parameters:
+             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                 Can be either:
+
+                     - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                       the Hub.
+                     - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                       with [`ModelMixin.save_pretrained`].
+                     - A [torch state
+                       dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+             cache_dir (`Union[str, os.PathLike]`, *optional*):
+                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                 is not used.
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                 cached versions if they exist.
+             proxies (`Dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+             local_files_only (`bool`, *optional*, defaults to `False`):
+                 Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                 won't be downloaded from the Hub.
+             token (`str` or *bool*, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
+             revision (`str`, *optional*, defaults to `"main"`):
+                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                 allowed by Git.
+             subfolder (`str`, *optional*, defaults to `""`):
+                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+         """
+         # Load the main state dict first, which has the LoRA layers for either the
+         # transformer or the text encoder, or both.
+         cache_dir = kwargs.pop("cache_dir", None)
+         force_download = kwargs.pop("force_download", False)
+         proxies = kwargs.pop("proxies", None)
+         local_files_only = kwargs.pop("local_files_only", None)
+         token = kwargs.pop("token", None)
+         revision = kwargs.pop("revision", None)
+         subfolder = kwargs.pop("subfolder", None)
+         weight_name = kwargs.pop("weight_name", None)
+         use_safetensors = kwargs.pop("use_safetensors", None)
+
+         allow_pickle = False
+         if use_safetensors is None:
+             use_safetensors = True
+             allow_pickle = True
+
+         user_agent = {
+             "file_type": "attn_procs_weights",
+             "framework": "pytorch",
+         }
+
+         state_dict = _fetch_state_dict(
+             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+             weight_name=weight_name,
+             use_safetensors=use_safetensors,
+             local_files_only=local_files_only,
+             cache_dir=cache_dir,
+             force_download=force_download,
+             proxies=proxies,
+             token=token,
+             revision=revision,
+             subfolder=subfolder,
+             user_agent=user_agent,
+             allow_pickle=allow_pickle,
+         )
+
+         is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+         if is_dora_scale_present:
+             warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+             logger.warning(warn_msg)
+             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+         return state_dict
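Because `lora_state_dict` also accepts an already-loaded dict, it is handy for inspecting or sanitizing a checkpoint before anything touches the model; DoRA `dora_scale` keys, if present, are filtered at this step. A small sketch (the local file name is a placeholder):

```py
from safetensors.torch import load_file
from diffusers import MochiPipeline  # mixes in Mochi1LoraLoaderMixin

raw = load_file("pytorch_lora_weights.safetensors")  # hypothetical local file
state_dict = MochiPipeline.lora_state_dict(raw)

# List a few of the modules the LoRA targets.
print(sorted({k.split(".lora_")[0] for k in state_dict if ".lora_" in k})[:5])
```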
+
+     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+     def load_lora_weights(
+         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+     ):
+         """
+         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All kwargs
+         are forwarded to `self.lora_state_dict`. See
+         [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+         See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+         dict is loaded into `self.transformer`.
+
+         Parameters:
+             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+             adapter_name (`str`, *optional*):
+                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                 `default_{i}` where i is the total number of adapters being loaded.
+             low_cpu_mem_usage (`bool`, *optional*):
+                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                 weights.
+             kwargs (`dict`, *optional*):
+                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+         """
+         if not USE_PEFT_BACKEND:
+             raise ValueError("PEFT backend is required for this method.")
+
+         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+             raise ValueError(
+                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+             )
+
+         # if a dict is passed, copy it instead of modifying it inplace
+         if isinstance(pretrained_model_name_or_path_or_dict, dict):
+             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+         is_correct_format = all("lora" in key for key in state_dict.keys())
+         if not is_correct_format:
+             raise ValueError("Invalid LoRA checkpoint.")
+
+         self.load_lora_into_transformer(
+             state_dict,
+             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+             adapter_name=adapter_name,
+             _pipeline=self,
+             low_cpu_mem_usage=low_cpu_mem_usage,
+         )
+
+     @classmethod
+     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
+     def load_lora_into_transformer(
+         cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+     ):
+         """
+         This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+         Parameters:
+             state_dict (`dict`):
+                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                 encoder lora layers.
+             transformer (`MochiTransformer3DModel`):
+                 The Transformer model to load the LoRA layers into.
+             adapter_name (`str`, *optional*):
+                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                 `default_{i}` where i is the total number of adapters being loaded.
+             low_cpu_mem_usage (`bool`, *optional*):
+                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                 weights.
+         """
+         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+             raise ValueError(
+                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+             )
+
+         # Load the layers corresponding to transformer.
+         logger.info(f"Loading {cls.transformer_name}.")
+         transformer.load_lora_adapter(
+             state_dict,
+             network_alphas=None,
+             adapter_name=adapter_name,
+             _pipeline=_pipeline,
+             low_cpu_mem_usage=low_cpu_mem_usage,
+         )
+
+     @classmethod
+     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+     def save_lora_weights(
+         cls,
+         save_directory: Union[str, os.PathLike],
+         transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+         is_main_process: bool = True,
+         weight_name: str = None,
+         save_function: Callable = None,
+         safe_serialization: bool = True,
+     ):
+         r"""
+         Save the LoRA parameters corresponding to the transformer.
+
+         Arguments:
+             save_directory (`str` or `os.PathLike`):
+                 Directory to save LoRA parameters to. Will be created if it doesn't exist.
+             transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                 State dict of the LoRA layers corresponding to the `transformer`.
+             is_main_process (`bool`, *optional*, defaults to `True`):
+                 Whether the process calling this is the main process or not. Useful during distributed training when
+                 you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                 main process to avoid race conditions.
+             save_function (`Callable`):
+                 The function to use to save the state dictionary. Useful during distributed training when you need to
+                 replace `torch.save` with another method. Can be configured with the environment variable
+                 `DIFFUSERS_SAVE_MODE`.
+             safe_serialization (`bool`, *optional*, defaults to `True`):
+                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+         """
+         state_dict = {}
+
+         if not transformer_lora_layers:
+             raise ValueError("You must pass `transformer_lora_layers`.")
+
+         if transformer_lora_layers:
+             state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+         # Save the model
+         cls.write_lora_layers(
+             state_dict=state_dict,
+             save_directory=save_directory,
+             is_main_process=is_main_process,
+             weight_name=weight_name,
+             save_function=save_function,
+             safe_serialization=safe_serialization,
+         )
+
+     def fuse_lora(
+         self,
+         components: List[str] = ["transformer"],
+         lora_scale: float = 1.0,
+         safe_fusing: bool = False,
+         adapter_names: Optional[List[str]] = None,
+         **kwargs,
+     ):
+         r"""
+         Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+         <Tip warning={true}>
+
+         This is an experimental API.
+
+         </Tip>
+
+         Args:
+             components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+             lora_scale (`float`, defaults to 1.0):
+                 Controls how much to influence the outputs with the LoRA parameters.
+             safe_fusing (`bool`, defaults to `False`):
+                 Whether to check fused weights for NaN values before fusing, and to skip fusing if NaN values are
+                 present.
+             adapter_names (`List[str]`, *optional*):
+                 Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+         Example:
+
+         ```py
+         from diffusers import DiffusionPipeline
+         import torch
+
+         pipeline = DiffusionPipeline.from_pretrained(
+             "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+         ).to("cuda")
+         pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+         pipeline.fuse_lora(lora_scale=0.7)
+         ```
+         """
+         super().fuse_lora(
+             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+         )
+
+     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+         r"""
+         Reverses the effect of
+         [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+         <Tip warning={true}>
+
+         This is an experimental API.
+
+         </Tip>
+
+         Args:
+             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+         """
+         super().unfuse_lora(components=components)
+
+
+ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
+     r"""
+     Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`].
+     """
+
+     _lora_loadable_modules = ["transformer"]
+     transformer_name = TRANSFORMER_NAME
+
+     @classmethod
+     @validate_hf_hub_args
+     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+     def lora_state_dict(
+         cls,
+         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+         **kwargs,
+     ):
+         r"""
+         Return state dict for lora weights and the network alphas.
+
+         <Tip warning={true}>
+
+         We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+         This function is experimental and might change in the future.
+
+         </Tip>
+
+         Parameters:
+             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                 Can be either:
+
+                     - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                       the Hub.
+                     - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                       with [`ModelMixin.save_pretrained`].
+                     - A [torch state
+                       dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+             cache_dir (`Union[str, os.PathLike]`, *optional*):
+                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                 is not used.
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                 cached versions if they exist.
+             proxies (`Dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+             local_files_only (`bool`, *optional*, defaults to `False`):
+                 Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                 won't be downloaded from the Hub.
+             token (`str` or *bool*, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
+             revision (`str`, *optional*, defaults to `"main"`):
+                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                 allowed by Git.
+             subfolder (`str`, *optional*, defaults to `""`):
+                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+         """
+         # Load the main state dict first, which has the LoRA layers for either the
+         # transformer or the text encoder, or both.
+         cache_dir = kwargs.pop("cache_dir", None)
+         force_download = kwargs.pop("force_download", False)
+         proxies = kwargs.pop("proxies", None)
+         local_files_only = kwargs.pop("local_files_only", None)
+         token = kwargs.pop("token", None)
+         revision = kwargs.pop("revision", None)
+         subfolder = kwargs.pop("subfolder", None)
+         weight_name = kwargs.pop("weight_name", None)
+         use_safetensors = kwargs.pop("use_safetensors", None)
+
+         allow_pickle = False
+         if use_safetensors is None:
+             use_safetensors = True
+             allow_pickle = True
+
+         user_agent = {
+             "file_type": "attn_procs_weights",
+             "framework": "pytorch",
+         }
+
+         state_dict = _fetch_state_dict(
+             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+             weight_name=weight_name,
+             use_safetensors=use_safetensors,
+             local_files_only=local_files_only,
+             cache_dir=cache_dir,
+             force_download=force_download,
+             proxies=proxies,
+             token=token,
+             revision=revision,
+             subfolder=subfolder,
+             user_agent=user_agent,
+             allow_pickle=allow_pickle,
+         )
+
+         is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+         if is_dora_scale_present:
+             warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+             logger.warning(warn_msg)
+             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+         return state_dict
+
+     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+     def load_lora_weights(
+         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+     ):
+         """
+         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All kwargs
+         are forwarded to `self.lora_state_dict`. See
+         [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+         See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+         dict is loaded into `self.transformer`.
+
+         Parameters:
+             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+             adapter_name (`str`, *optional*):
+                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                 `default_{i}` where i is the total number of adapters being loaded.
+             low_cpu_mem_usage (`bool`, *optional*):
+                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                 weights.
+             kwargs (`dict`, *optional*):
+                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+         """
+         if not USE_PEFT_BACKEND:
+             raise ValueError("PEFT backend is required for this method.")
+
+         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+             raise ValueError(
+                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+             )
+
+         # if a dict is passed, copy it instead of modifying it inplace
+         if isinstance(pretrained_model_name_or_path_or_dict, dict):
+             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+         is_correct_format = all("lora" in key for key in state_dict.keys())
+         if not is_correct_format:
+             raise ValueError("Invalid LoRA checkpoint.")
+
+         self.load_lora_into_transformer(
+             state_dict,
+             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+             adapter_name=adapter_name,
+             _pipeline=self,
+             low_cpu_mem_usage=low_cpu_mem_usage,
+         )
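The `low_cpu_mem_usage` knob documented above is forwarded through `kwargs` and, with a recent enough `peft`, skips the random initialization of adapter weights. A sketch for LTX-Video (the base checkpoint `Lightricks/LTX-Video` is public; the LoRA repo id is a hypothetical placeholder):

```py
import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights(
    "some-user/ltx-video-lora",  # hypothetical repo id
    adapter_name="style",
    low_cpu_mem_usage=True,  # requires peft >= 0.13
)
```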
3052
+
3053
+ @classmethod
3054
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
3055
+ def load_lora_into_transformer(
3056
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
3057
+ ):
3058
+ """
3059
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
3060
+
3061
+ Parameters:
3062
+ state_dict (`dict`):
3063
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
3064
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
3065
+ encoder lora layers.
3066
+ transformer (`LTXVideoTransformer3DModel`):
3067
+ The Transformer model to load the LoRA layers into.
3068
+ adapter_name (`str`, *optional*):
3069
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3070
+ `default_{i}` where i is the total number of adapters being loaded.
3071
+ low_cpu_mem_usage (`bool`, *optional*):
3072
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3073
+ weights.
3074
+ """
3075
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3076
+ raise ValueError(
3077
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3078
+ )
3079
+
3080
+ # Load the layers corresponding to transformer.
3081
+ logger.info(f"Loading {cls.transformer_name}.")
3082
+ transformer.load_lora_adapter(
3083
+ state_dict,
3084
+ network_alphas=None,
3085
+ adapter_name=adapter_name,
3086
+ _pipeline=_pipeline,
3087
+ low_cpu_mem_usage=low_cpu_mem_usage,
3088
+ )
3089
+
3090
+ @classmethod
3091
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
3092
+ def save_lora_weights(
3093
+ cls,
3094
+ save_directory: Union[str, os.PathLike],
3095
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
3096
+ is_main_process: bool = True,
3097
+ weight_name: str = None,
3098
+ save_function: Callable = None,
3099
+ safe_serialization: bool = True,
3100
+ ):
3101
+ r"""
3102
+ Save the LoRA parameters corresponding to the UNet and text encoder.
3103
+
3104
+ Arguments:
3105
+ save_directory (`str` or `os.PathLike`):
3106
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
3107
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
3108
+ State dict of the LoRA layers corresponding to the `transformer`.
3109
+ is_main_process (`bool`, *optional*, defaults to `True`):
3110
+ Whether the process calling this is the main process or not. Useful during distributed training and you
3111
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
3112
+ process to avoid race conditions.
3113
+ save_function (`Callable`):
3114
+ The function to use to save the state dictionary. Useful during distributed training when you need to
3115
+ replace `torch.save` with another method. Can be configured with the environment variable
3116
+ `DIFFUSERS_SAVE_MODE`.
3117
+ safe_serialization (`bool`, *optional*, defaults to `True`):
3118
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
3119
+ """
3120
+ state_dict = {}
3121
+
3122
+ if not transformer_lora_layers:
3123
+ raise ValueError("You must pass `transformer_lora_layers`.")
3124
+
3125
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
3127
+
3128
+ # Save the model
3129
+ cls.write_lora_layers(
3130
+ state_dict=state_dict,
3131
+ save_directory=save_directory,
3132
+ is_main_process=is_main_process,
3133
+ weight_name=weight_name,
3134
+ save_function=save_function,
3135
+ safe_serialization=safe_serialization,
3136
+ )
3137
+
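As a usage note, training scripts typically obtain `transformer_lora_layers` from PEFT before calling this method. A minimal sketch under that assumption (the save directory and weight name are illustrative, and a LoRA adapter is assumed to already be injected into `pipe.transformer`):

```py
# Sketch of the usual save path: extract the adapter weights with PEFT, then
# serialize them via save_lora_weights.
from peft.utils import get_peft_model_state_dict

transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)
pipe.save_lora_weights(
    save_directory="./my-ltx-lora",  # created if it doesn't exist
    transformer_lora_layers=transformer_lora_layers,
    weight_name="pytorch_lora_weights.safetensors",
    safe_serialization=True,
)
```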
3138
+ def fuse_lora(
3139
+ self,
3140
+ components: List[str] = ["transformer"],
3141
+ lora_scale: float = 1.0,
3142
+ safe_fusing: bool = False,
3143
+ adapter_names: Optional[List[str]] = None,
3144
+ **kwargs,
3145
+ ):
3146
+ r"""
3147
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
3148
+
3149
+ <Tip warning={true}>
3150
+
3151
+ This is an experimental API.
3152
+
3153
+ </Tip>
3154
+
3155
+ Args:
3156
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
3157
+ lora_scale (`float`, defaults to 1.0):
3158
+ Controls how much to influence the outputs with the LoRA parameters.
3159
+ safe_fusing (`bool`, defaults to `False`):
3160
+ Whether to check the fused weights for NaN values before fusing and, if values are NaN, not fusing them.
3161
+ adapter_names (`List[str]`, *optional*):
3162
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
3163
+
3164
+ Example:
3165
+
3166
+ ```py
3167
+ from diffusers import DiffusionPipeline
3168
+ import torch
3169
+
3170
+ pipeline = DiffusionPipeline.from_pretrained(
3171
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
3172
+ ).to("cuda")
3173
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
3174
+ pipeline.fuse_lora(lora_scale=0.7)
3175
+ ```
3176
+ """
3177
+ super().fuse_lora(
3178
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
3179
+ )
3180
+
3181
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
3182
+ r"""
3183
+ Reverses the effect of
3184
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
3185
+
3186
+ <Tip warning={true}>
3187
+
3188
+ This is an experimental API.
3189
+
3190
+ </Tip>
3191
+
3192
+ Args:
3193
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3194
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
3195
+ """
3196
+ super().unfuse_lora(components=components)
3197
+
3198
+
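To tie the class together, a hedged end-to-end sketch of the fuse/unfuse round trip. The prompt and frame count are arbitrary, and a LoRA is assumed to have been loaded as above:

```py
# Fuse the loaded adapter into the base weights for inference, then restore
# the original parameters afterwards.
pipe.fuse_lora(components=["transformer"], lora_scale=0.8)
video = pipe(prompt="a fox running through fresh snow", num_frames=49).frames[0]
pipe.unfuse_lora(components=["transformer"])
```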
3199
+ class SanaLoraLoaderMixin(LoraBaseMixin):
3200
+ r"""
3201
+ Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
3202
+ """
3203
+
3204
+ _lora_loadable_modules = ["transformer"]
3205
+ transformer_name = TRANSFORMER_NAME
3206
+
3207
+ @classmethod
3208
+ @validate_hf_hub_args
3209
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
3210
+ def lora_state_dict(
3211
+ cls,
3212
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
3213
+ **kwargs,
3214
+ ):
3215
+ r"""
3216
+ Return the state dict of the LoRA weights.
3217
+
3218
+ <Tip warning={true}>
3219
+
3220
+ We support loading A1111-formatted LoRA checkpoints in a limited capacity.
3221
+
3222
+ This function is experimental and might change in the future.
3223
+
3224
+ </Tip>
3225
+
3226
+ Parameters:
3227
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3228
+ Can be either:
3229
+
3230
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
3231
+ the Hub.
3232
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
3233
+ with [`ModelMixin.save_pretrained`].
3234
+ - A [torch state
3235
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
3236
+
3237
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
3238
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
3239
+ is not used.
3240
+ force_download (`bool`, *optional*, defaults to `False`):
3241
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
3242
+ cached versions if they exist.
3243
+
3244
+ proxies (`Dict[str, str]`, *optional*):
3245
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
3246
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
3247
+ local_files_only (`bool`, *optional*, defaults to `False`):
3248
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
3249
+ won't be downloaded from the Hub.
3250
+ token (`str` or *bool*, *optional*):
3251
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
3252
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
3253
+ revision (`str`, *optional*, defaults to `"main"`):
3254
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
3255
+ allowed by Git.
3256
+ subfolder (`str`, *optional*, defaults to `""`):
3257
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
3258
+
3259
+ """
3260
+ # Load the main state dict first which has the LoRA layers for either of
3261
+ # transformer and text encoder or both.
3262
+ cache_dir = kwargs.pop("cache_dir", None)
3263
+ force_download = kwargs.pop("force_download", False)
3264
+ proxies = kwargs.pop("proxies", None)
3265
+ local_files_only = kwargs.pop("local_files_only", None)
3266
+ token = kwargs.pop("token", None)
3267
+ revision = kwargs.pop("revision", None)
3268
+ subfolder = kwargs.pop("subfolder", None)
3269
+ weight_name = kwargs.pop("weight_name", None)
3270
+ use_safetensors = kwargs.pop("use_safetensors", None)
3271
+
3272
+ allow_pickle = False
3273
+ if use_safetensors is None:
3274
+ use_safetensors = True
3275
+ allow_pickle = True
3276
+
3277
+ user_agent = {
3278
+ "file_type": "attn_procs_weights",
3279
+ "framework": "pytorch",
3280
+ }
3281
+
3282
+ state_dict = _fetch_state_dict(
3283
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
3284
+ weight_name=weight_name,
3285
+ use_safetensors=use_safetensors,
3286
+ local_files_only=local_files_only,
3287
+ cache_dir=cache_dir,
3288
+ force_download=force_download,
3289
+ proxies=proxies,
3290
+ token=token,
3291
+ revision=revision,
3292
+ subfolder=subfolder,
3293
+ user_agent=user_agent,
3294
+ allow_pickle=allow_pickle,
3295
+ )
3296
+
3297
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
3298
+ if is_dora_scale_present:
3299
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake, please open an issue: https://github.com/huggingface/diffusers/issues/new."
3300
+ logger.warning(warn_msg)
3301
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
3302
+
3303
+ return state_dict
3304
+
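A small sketch of inspecting the returned dict before loading it. The repo id is a placeholder, and the key layout described in the comment is the typical PEFT convention rather than a guarantee:

```py
from diffusers import SanaPipeline

# Placeholder repo id; a local path or a raw dict works as well.
state_dict = SanaPipeline.lora_state_dict("a-hub-user/sana-lora")
for key, value in list(state_dict.items())[:5]:
    print(key, tuple(value.shape))
# Keys are typically prefixed with "transformer." and end in
# ".lora_A.weight" / ".lora_B.weight" for PEFT-style checkpoints.
```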
3305
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
3306
+ def load_lora_weights(
3307
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
3308
+ ):
3309
+ """
3310
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.
3311
+ All kwargs are forwarded to `self.lora_state_dict`. See
3312
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
3313
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
3314
+ dict is loaded into `self.transformer`.
3315
+
3316
+ Parameters:
3317
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3318
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3319
+ adapter_name (`str`, *optional*):
3320
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3321
+ `default_{i}` where i is the total number of adapters being loaded.
3322
+ low_cpu_mem_usage (`bool`, *optional*):
3323
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3324
+ weights.
3325
+ kwargs (`dict`, *optional*):
3326
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3327
+ """
3328
+ if not USE_PEFT_BACKEND:
3329
+ raise ValueError("PEFT backend is required for this method.")
3330
+
3331
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
3332
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3333
+ raise ValueError(
3334
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3335
+ )
3336
+
3337
+ # if a dict is passed, copy it instead of modifying it inplace
3338
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
3339
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
3340
+
3341
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
3342
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3343
+
3344
+ is_correct_format = all("lora" in key for key in state_dict.keys())
3345
+ if not is_correct_format:
3346
+ raise ValueError("Invalid LoRA checkpoint.")
3347
+
3348
+ self.load_lora_into_transformer(
3349
+ state_dict,
3350
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
3351
+ adapter_name=adapter_name,
3352
+ _pipeline=self,
3353
+ low_cpu_mem_usage=low_cpu_mem_usage,
3354
+ )
3355
+
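For context, a hedged usage sketch of the public entry point. The base checkpoint id follows the Sana release naming and the LoRA repo id is a placeholder:

```py
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("a-hub-user/sana-lora", adapter_name="style")  # placeholder
image = pipe(prompt="a pixel-art castle at dusk").images[0]
```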
3356
+ @classmethod
3357
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
3358
+ def load_lora_into_transformer(
3359
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
3360
+ ):
3361
+ """
3362
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
3363
+
3364
+ Parameters:
3365
+ state_dict (`dict`):
3366
+ A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
3367
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
3368
+ them from text encoder LoRA layers.
3369
+ transformer (`SanaTransformer2DModel`):
3370
+ The Transformer model to load the LoRA layers into.
3371
+ adapter_name (`str`, *optional*):
3372
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3373
+ `default_{i}` where i is the total number of adapters being loaded.
3374
+ low_cpu_mem_usage (`bool`, *optional*):
3375
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3376
+ weights.
3377
+ """
3378
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3379
+ raise ValueError(
3380
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3381
+ )
3382
+
3383
+ # Load the layers corresponding to the transformer.
3384
+ logger.info(f"Loading {cls.transformer_name}.")
3385
+ transformer.load_lora_adapter(
3386
+ state_dict,
3387
+ network_alphas=None,
3388
+ adapter_name=adapter_name,
3389
+ _pipeline=_pipeline,
3390
+ low_cpu_mem_usage=low_cpu_mem_usage,
3391
+ )
3392
+
3393
+ @classmethod
3394
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
3395
+ def save_lora_weights(
3396
+ cls,
3397
+ save_directory: Union[str, os.PathLike],
3398
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
3399
+ is_main_process: bool = True,
3400
+ weight_name: str = None,
3401
+ save_function: Callable = None,
3402
+ safe_serialization: bool = True,
3403
+ ):
3404
+ r"""
3405
+ Save the LoRA parameters corresponding to the transformer.
3406
+
3407
+ Arguments:
3408
+ save_directory (`str` or `os.PathLike`):
3409
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
3410
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
3411
+ State dict of the LoRA layers corresponding to the `transformer`.
3412
+ is_main_process (`bool`, *optional*, defaults to `True`):
3413
+ Whether the process calling this is the main process or not. Useful during distributed training and you
3414
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
3415
+ process to avoid race conditions.
3416
+ save_function (`Callable`):
3417
+ The function to use to save the state dictionary. Useful during distributed training when you need to
3418
+ replace `torch.save` with another method. Can be configured with the environment variable
3419
+ `DIFFUSERS_SAVE_MODE`.
3420
+ safe_serialization (`bool`, *optional*, defaults to `True`):
3421
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
3422
+ """
3423
+ state_dict = {}
3424
+
3425
+ if not transformer_lora_layers:
3426
+ raise ValueError("You must pass `transformer_lora_layers`.")
3427
+
3428
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
3430
+
3431
+ # Save the model
3432
+ cls.write_lora_layers(
3433
+ state_dict=state_dict,
3434
+ save_directory=save_directory,
3435
+ is_main_process=is_main_process,
3436
+ weight_name=weight_name,
3437
+ save_function=save_function,
3438
+ safe_serialization=safe_serialization,
3439
+ )
3440
+
3441
+ def fuse_lora(
3442
+ self,
3443
+ components: List[str] = ["transformer"],
3444
+ lora_scale: float = 1.0,
3445
+ safe_fusing: bool = False,
3446
+ adapter_names: Optional[List[str]] = None,
3447
+ **kwargs,
3448
+ ):
3449
+ r"""
3450
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
3451
+
3452
+ <Tip warning={true}>
3453
+
3454
+ This is an experimental API.
3455
+
3456
+ </Tip>
3457
+
3458
+ Args:
3459
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
3460
+ lora_scale (`float`, defaults to 1.0):
3461
+ Controls how much to influence the outputs with the LoRA parameters.
3462
+ safe_fusing (`bool`, defaults to `False`):
3463
+ Whether to check the fused weights for NaN values before fusing and, if values are NaN, not fusing them.
3464
+ adapter_names (`List[str]`, *optional*):
3465
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
3466
+
3467
+ Example:
3468
+
3469
+ ```py
3470
+ from diffusers import DiffusionPipeline
3471
+ import torch
3472
+
3473
+ pipeline = DiffusionPipeline.from_pretrained(
3474
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
3475
+ ).to("cuda")
3476
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
3477
+ pipeline.fuse_lora(lora_scale=0.7)
3478
+ ```
3479
+ """
3480
+ super().fuse_lora(
3481
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
3482
+ )
3483
+
3484
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
3485
+ r"""
3486
+ Reverses the effect of
3487
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
3488
+
3489
+ <Tip warning={true}>
3490
+
3491
+ This is an experimental API.
3492
+
3493
+ </Tip>
3494
+
3495
+ Args:
3496
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3497
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
3498
+ """
3499
+ super().unfuse_lora(components=components)
3500
+
3501
+
3502
+ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
3503
+ r"""
3504
+ Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
3505
+ """
3506
+
3507
+ _lora_loadable_modules = ["transformer"]
3508
+ transformer_name = TRANSFORMER_NAME
3509
+
3510
+ @classmethod
3511
+ @validate_hf_hub_args
3512
+ def lora_state_dict(
3513
+ cls,
3514
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
3515
+ **kwargs,
3516
+ ):
3517
+ r"""
3518
+ Return the state dict of the LoRA weights.
3519
+
3520
+ <Tip warning={true}>
3521
+
3522
+ We support loading LoRA checkpoints in the original HunyuanVideo format.
3523
+
3524
+ This function is experimental and might change in the future.
3525
+
3526
+ </Tip>
3527
+
3528
+ Parameters:
3529
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3530
+ Can be either:
3531
+
3532
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
3533
+ the Hub.
3534
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
3535
+ with [`ModelMixin.save_pretrained`].
3536
+ - A [torch state
3537
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
3538
+
3539
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
3540
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
3541
+ is not used.
3542
+ force_download (`bool`, *optional*, defaults to `False`):
3543
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
3544
+ cached versions if they exist.
3545
+
3546
+ proxies (`Dict[str, str]`, *optional*):
3547
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
3548
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
3549
+ local_files_only (`bool`, *optional*, defaults to `False`):
3550
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
3551
+ won't be downloaded from the Hub.
3552
+ token (`str` or *bool*, *optional*):
3553
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
3554
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
3555
+ revision (`str`, *optional*, defaults to `"main"`):
3556
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
3557
+ allowed by Git.
3558
+ subfolder (`str`, *optional*, defaults to `""`):
3559
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
3560
+
3561
+ """
3562
+ # Load the main state dict first which has the LoRA layers for either of
3563
+ # transformer and text encoder or both.
3564
+ cache_dir = kwargs.pop("cache_dir", None)
3565
+ force_download = kwargs.pop("force_download", False)
3566
+ proxies = kwargs.pop("proxies", None)
3567
+ local_files_only = kwargs.pop("local_files_only", None)
3568
+ token = kwargs.pop("token", None)
3569
+ revision = kwargs.pop("revision", None)
3570
+ subfolder = kwargs.pop("subfolder", None)
3571
+ weight_name = kwargs.pop("weight_name", None)
3572
+ use_safetensors = kwargs.pop("use_safetensors", None)
3573
+
3574
+ allow_pickle = False
3575
+ if use_safetensors is None:
3576
+ use_safetensors = True
3577
+ allow_pickle = True
3578
+
3579
+ user_agent = {
3580
+ "file_type": "attn_procs_weights",
3581
+ "framework": "pytorch",
3582
+ }
3583
+
3584
+ state_dict = _fetch_state_dict(
3585
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
3586
+ weight_name=weight_name,
3587
+ use_safetensors=use_safetensors,
3588
+ local_files_only=local_files_only,
3589
+ cache_dir=cache_dir,
3590
+ force_download=force_download,
3591
+ proxies=proxies,
3592
+ token=token,
3593
+ revision=revision,
3594
+ subfolder=subfolder,
3595
+ user_agent=user_agent,
3596
+ allow_pickle=allow_pickle,
3597
+ )
3598
+
3599
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
3600
+ if is_dora_scale_present:
3601
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake, please open an issue: https://github.com/huggingface/diffusers/issues/new."
3602
+ logger.warning(warn_msg)
3603
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
3604
+
3605
+ is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
3606
+ if is_original_hunyuan_video:
3607
+ state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
3608
+
3609
+ return state_dict
3610
+
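The `img_attn_qkv` probe above is what distinguishes original-format checkpoints from the diffusers layout. A standalone sketch of that heuristic applied to a raw file (the local filename is a placeholder; `lora_state_dict` performs this check, and the conversion, automatically):

```py
# Minimal sketch of the detection heuristic on a manually loaded state dict.
from safetensors.torch import load_file

raw = load_file("hunyuan_lora.safetensors")  # placeholder local file
needs_conversion = any("img_attn_qkv" in k for k in raw)
print("original HunyuanVideo format:", needs_conversion)
```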
3611
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
3612
+ def load_lora_weights(
3613
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
3614
+ ):
3615
+ """
3616
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.
3617
+ All kwargs are forwarded to `self.lora_state_dict`. See
3618
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
3619
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
3620
+ dict is loaded into `self.transformer`.
3621
+
3622
+ Parameters:
3623
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3624
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3625
+ adapter_name (`str`, *optional*):
3626
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3627
+ `default_{i}` where i is the total number of adapters being loaded.
3628
+ low_cpu_mem_usage (`bool`, *optional*):
3629
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3630
+ weights.
3631
+ kwargs (`dict`, *optional*):
3632
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3633
+ """
3634
+ if not USE_PEFT_BACKEND:
3635
+ raise ValueError("PEFT backend is required for this method.")
3636
+
3637
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
3638
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3639
+ raise ValueError(
3640
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3641
+ )
3642
+
3643
+ # if a dict is passed, copy it instead of modifying it inplace
3644
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
3645
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
3646
+
3647
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
3648
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3649
+
3650
+ is_correct_format = all("lora" in key for key in state_dict.keys())
3651
+ if not is_correct_format:
3652
+ raise ValueError("Invalid LoRA checkpoint.")
3653
+
3654
+ self.load_lora_into_transformer(
3655
+ state_dict,
3656
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
3657
+ adapter_name=adapter_name,
3658
+ _pipeline=self,
3659
+ low_cpu_mem_usage=low_cpu_mem_usage,
3660
+ )
3661
+
3662
+ @classmethod
3663
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
3664
+ def load_lora_into_transformer(
3665
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
3666
+ ):
3667
+ """
3668
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
3669
+
3670
+ Parameters:
3671
+ state_dict (`dict`):
3672
+ A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
3673
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
3674
+ them from text encoder LoRA layers.
3675
+ transformer (`HunyuanVideoTransformer3DModel`):
3676
+ The Transformer model to load the LoRA layers into.
3677
+ adapter_name (`str`, *optional*):
3678
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3679
+ `default_{i}` where i is the total number of adapters being loaded.
3680
+ low_cpu_mem_usage (`bool`, *optional*):
3681
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3682
+ weights.
3683
+ """
3684
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3685
+ raise ValueError(
3686
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3687
+ )
3688
+
3689
+ # Load the layers corresponding to the transformer.
3690
+ logger.info(f"Loading {cls.transformer_name}.")
3691
+ transformer.load_lora_adapter(
3692
+ state_dict,
3693
+ network_alphas=None,
3694
+ adapter_name=adapter_name,
3695
+ _pipeline=_pipeline,
3696
+ low_cpu_mem_usage=low_cpu_mem_usage,
3697
+ )
3698
+
3699
+ @classmethod
3700
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
3701
+ def save_lora_weights(
3702
+ cls,
3703
+ save_directory: Union[str, os.PathLike],
3704
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
3705
+ is_main_process: bool = True,
3706
+ weight_name: str = None,
3707
+ save_function: Callable = None,
3708
+ safe_serialization: bool = True,
3709
+ ):
3710
+ r"""
3711
+ Save the LoRA parameters corresponding to the transformer.
3712
+
3713
+ Arguments:
3714
+ save_directory (`str` or `os.PathLike`):
3715
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
3716
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
3717
+ State dict of the LoRA layers corresponding to the `transformer`.
3718
+ is_main_process (`bool`, *optional*, defaults to `True`):
3719
+ Whether the process calling this is the main process or not. Useful during distributed training and you
3720
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
3721
+ process to avoid race conditions.
3722
+ save_function (`Callable`):
3723
+ The function to use to save the state dictionary. Useful during distributed training when you need to
3724
+ replace `torch.save` with another method. Can be configured with the environment variable
3725
+ `DIFFUSERS_SAVE_MODE`.
3726
+ safe_serialization (`bool`, *optional*, defaults to `True`):
3727
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
3728
+ """
3729
+ state_dict = {}
3730
+
3731
+ if not transformer_lora_layers:
3732
+ raise ValueError("You must pass `transformer_lora_layers`.")
3733
+
3734
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
3736
+
3737
+ # Save the model
3738
+ cls.write_lora_layers(
3739
+ state_dict=state_dict,
3740
+ save_directory=save_directory,
3741
+ is_main_process=is_main_process,
3742
+ weight_name=weight_name,
3743
+ save_function=save_function,
3744
+ safe_serialization=safe_serialization,
3745
+ )
3746
+
3747
+ def fuse_lora(
3748
+ self,
3749
+ components: List[str] = ["transformer"],
3750
+ lora_scale: float = 1.0,
3751
+ safe_fusing: bool = False,
3752
+ adapter_names: Optional[List[str]] = None,
3753
+ **kwargs,
3754
+ ):
3755
+ r"""
3756
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
3757
+
3758
+ <Tip warning={true}>
3759
+
3760
+ This is an experimental API.
3761
+
3762
+ </Tip>
3763
+
3764
+ Args:
3765
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
3766
+ lora_scale (`float`, defaults to 1.0):
3767
+ Controls how much to influence the outputs with the LoRA parameters.
3768
+ safe_fusing (`bool`, defaults to `False`):
3769
+ Whether to check the fused weights for NaN values before fusing and, if values are NaN, not fusing them.
3770
+ adapter_names (`List[str]`, *optional*):
3771
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
3772
+
3773
+ Example:
3774
+
3775
+ ```py
3776
+ from diffusers import DiffusionPipeline
3777
+ import torch
3778
+
3779
+ pipeline = DiffusionPipeline.from_pretrained(
3780
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
3781
+ ).to("cuda")
3782
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
3783
+ pipeline.fuse_lora(lora_scale=0.7)
3784
+ ```
3785
+ """
3786
+ super().fuse_lora(
3787
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
3788
+ )
3789
+
3790
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
3791
+ r"""
3792
+ Reverses the effect of
3793
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
3794
+
3795
+ <Tip warning={true}>
3796
+
3797
+ This is an experimental API.
3798
+
3799
+ </Tip>
3800
+
3801
+ Args:
3802
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3803
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
3804
+ """
3805
+ super().unfuse_lora(components=components)
3806
+
3807
+
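Finally, a hedged end-to-end sketch for this mixin. The base repo id is assumed to be the community diffusers-format checkpoint, and the LoRA repo id is a placeholder:

```py
import torch
from diffusers import HunyuanVideoPipeline

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("a-hub-user/hunyuan-video-lora")  # placeholder repo id
video = pipe(prompt="a cat walking on grass", num_frames=61).frames[0]
```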
3808
+ class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
3809
+ def __init__(self, *args, **kwargs):
3810
+ deprecation_message = "LoraLoaderMixin is deprecated and will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin` instead."
3811
+ deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)
3812
+ super().__init__(*args, **kwargs)
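
Since the shim only warns and then forwards to the new class, downstream code can migrate by renaming the base class. A minimal sketch:

```py
# Migration sketch: subclass the new mixin name directly to avoid the
# deprecation warning emitted by LoraLoaderMixin.
from diffusers.loaders import StableDiffusionLoraLoaderMixin


class MyCustomLoader(StableDiffusionLoraLoaderMixin):
    ...
```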