diffusers 0.33.1__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (478)
  1. diffusers/__init__.py +48 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/dependency_versions_check.py +1 -1
  7. diffusers/dependency_versions_table.py +1 -1
  8. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  9. diffusers/hooks/faster_cache.py +2 -2
  10. diffusers/hooks/group_offloading.py +128 -29
  11. diffusers/hooks/hooks.py +2 -2
  12. diffusers/hooks/layerwise_casting.py +3 -3
  13. diffusers/hooks/pyramid_attention_broadcast.py +1 -1
  14. diffusers/image_processor.py +7 -2
  15. diffusers/loaders/__init__.py +4 -0
  16. diffusers/loaders/ip_adapter.py +5 -14
  17. diffusers/loaders/lora_base.py +212 -111
  18. diffusers/loaders/lora_conversion_utils.py +275 -34
  19. diffusers/loaders/lora_pipeline.py +1554 -819
  20. diffusers/loaders/peft.py +52 -109
  21. diffusers/loaders/single_file.py +2 -2
  22. diffusers/loaders/single_file_model.py +20 -4
  23. diffusers/loaders/single_file_utils.py +225 -5
  24. diffusers/loaders/textual_inversion.py +3 -2
  25. diffusers/loaders/transformer_flux.py +1 -1
  26. diffusers/loaders/transformer_sd3.py +2 -2
  27. diffusers/loaders/unet.py +2 -16
  28. diffusers/loaders/unet_loader_utils.py +1 -1
  29. diffusers/loaders/utils.py +1 -1
  30. diffusers/models/__init__.py +15 -1
  31. diffusers/models/activations.py +5 -5
  32. diffusers/models/adapter.py +2 -3
  33. diffusers/models/attention.py +4 -4
  34. diffusers/models/attention_flax.py +10 -10
  35. diffusers/models/attention_processor.py +14 -10
  36. diffusers/models/auto_model.py +47 -10
  37. diffusers/models/autoencoders/__init__.py +1 -0
  38. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  39. diffusers/models/autoencoders/autoencoder_dc.py +3 -3
  40. diffusers/models/autoencoders/autoencoder_kl.py +4 -4
  41. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  42. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  43. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
  44. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  45. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  46. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  47. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  48. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  49. diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
  50. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  52. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  53. diffusers/models/autoencoders/vae.py +13 -2
  54. diffusers/models/autoencoders/vq_model.py +2 -2
  55. diffusers/models/cache_utils.py +1 -1
  56. diffusers/models/controlnet.py +1 -1
  57. diffusers/models/controlnet_flux.py +1 -1
  58. diffusers/models/controlnet_sd3.py +1 -1
  59. diffusers/models/controlnet_sparsectrl.py +1 -1
  60. diffusers/models/controlnets/__init__.py +1 -0
  61. diffusers/models/controlnets/controlnet.py +3 -3
  62. diffusers/models/controlnets/controlnet_flax.py +1 -1
  63. diffusers/models/controlnets/controlnet_flux.py +16 -15
  64. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  65. diffusers/models/controlnets/controlnet_sana.py +290 -0
  66. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  67. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  68. diffusers/models/controlnets/controlnet_union.py +1 -1
  69. diffusers/models/controlnets/controlnet_xs.py +7 -7
  70. diffusers/models/controlnets/multicontrolnet.py +4 -5
  71. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  72. diffusers/models/downsampling.py +2 -2
  73. diffusers/models/embeddings.py +10 -12
  74. diffusers/models/embeddings_flax.py +2 -2
  75. diffusers/models/lora.py +3 -3
  76. diffusers/models/modeling_utils.py +44 -14
  77. diffusers/models/normalization.py +4 -4
  78. diffusers/models/resnet.py +2 -2
  79. diffusers/models/resnet_flax.py +1 -1
  80. diffusers/models/transformers/__init__.py +5 -0
  81. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  82. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  83. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  84. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  85. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  86. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  87. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  88. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  89. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  90. diffusers/models/transformers/prior_transformer.py +1 -1
  91. diffusers/models/transformers/sana_transformer.py +8 -3
  92. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  93. diffusers/models/transformers/t5_film_transformer.py +3 -3
  94. diffusers/models/transformers/transformer_2d.py +1 -1
  95. diffusers/models/transformers/transformer_allegro.py +1 -1
  96. diffusers/models/transformers/transformer_chroma.py +742 -0
  97. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  98. diffusers/models/transformers/transformer_cogview4.py +317 -25
  99. diffusers/models/transformers/transformer_cosmos.py +579 -0
  100. diffusers/models/transformers/transformer_flux.py +9 -11
  101. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  102. diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
  103. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  104. diffusers/models/transformers/transformer_ltx.py +2 -2
  105. diffusers/models/transformers/transformer_lumina2.py +1 -1
  106. diffusers/models/transformers/transformer_mochi.py +1 -1
  107. diffusers/models/transformers/transformer_omnigen.py +2 -2
  108. diffusers/models/transformers/transformer_sd3.py +7 -7
  109. diffusers/models/transformers/transformer_temporal.py +1 -1
  110. diffusers/models/transformers/transformer_wan.py +24 -8
  111. diffusers/models/transformers/transformer_wan_vace.py +393 -0
  112. diffusers/models/unets/unet_1d.py +1 -1
  113. diffusers/models/unets/unet_1d_blocks.py +1 -1
  114. diffusers/models/unets/unet_2d.py +1 -1
  115. diffusers/models/unets/unet_2d_blocks.py +1 -1
  116. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  117. diffusers/models/unets/unet_2d_condition.py +2 -2
  118. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  119. diffusers/models/unets/unet_3d_blocks.py +1 -1
  120. diffusers/models/unets/unet_3d_condition.py +3 -3
  121. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  122. diffusers/models/unets/unet_kandinsky3.py +1 -1
  123. diffusers/models/unets/unet_motion_model.py +2 -2
  124. diffusers/models/unets/unet_stable_cascade.py +1 -1
  125. diffusers/models/upsampling.py +2 -2
  126. diffusers/models/vae_flax.py +2 -2
  127. diffusers/models/vq_model.py +1 -1
  128. diffusers/pipelines/__init__.py +37 -6
  129. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  130. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  131. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  132. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  133. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  134. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  135. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  136. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  137. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  138. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  139. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  140. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  141. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
  142. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  143. diffusers/pipelines/auto_pipeline.py +6 -7
  144. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  145. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  146. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  147. diffusers/pipelines/chroma/__init__.py +49 -0
  148. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  149. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  150. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  151. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
  152. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
  153. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
  154. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
  155. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  156. diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
  157. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  158. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  159. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  160. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  161. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  162. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
  163. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  164. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  165. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  166. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  167. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  168. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
  169. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
  170. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
  171. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  172. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  173. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  174. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  175. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  176. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  177. diffusers/pipelines/cosmos/__init__.py +54 -0
  178. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  179. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  180. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  181. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  182. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  183. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  184. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  185. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  186. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  187. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  188. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  189. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  190. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  191. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  192. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  193. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  194. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  195. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  196. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  197. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  198. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  199. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  200. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  201. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  202. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  203. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  204. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
  205. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  206. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  207. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  208. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  209. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  210. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  211. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  212. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  213. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  214. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  215. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  216. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  217. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  218. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  219. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  220. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  221. diffusers/pipelines/flux/modeling_flux.py +1 -1
  222. diffusers/pipelines/flux/pipeline_flux.py +10 -17
  223. diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
  224. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
  225. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
  226. diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
  227. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
  228. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
  229. diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
  230. diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
  231. diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
  232. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  233. diffusers/pipelines/free_init_utils.py +2 -2
  234. diffusers/pipelines/free_noise_utils.py +3 -3
  235. diffusers/pipelines/hidream_image/__init__.py +47 -0
  236. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  237. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  238. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  239. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  240. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
  241. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  242. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  243. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  244. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  245. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  246. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  247. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  248. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  249. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  250. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  251. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  252. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  253. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  254. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  255. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  256. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  257. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  258. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  259. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  260. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  261. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  262. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  263. diffusers/pipelines/kolors/text_encoder.py +3 -3
  264. diffusers/pipelines/kolors/tokenizer.py +1 -1
  265. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  266. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  267. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  268. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  269. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  270. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  271. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  272. diffusers/pipelines/ltx/__init__.py +4 -0
  273. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  274. diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
  275. diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
  276. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
  277. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  278. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  279. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  280. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  281. diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
  282. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  283. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  284. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  285. diffusers/pipelines/onnx_utils.py +15 -2
  286. diffusers/pipelines/pag/pag_utils.py +2 -2
  287. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  288. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  289. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  290. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  291. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  292. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  293. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  294. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  295. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  296. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  297. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  298. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  299. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  300. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  301. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  302. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  303. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  304. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  305. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  306. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  307. diffusers/pipelines/pipeline_flax_utils.py +3 -4
  308. diffusers/pipelines/pipeline_loading_utils.py +89 -13
  309. diffusers/pipelines/pipeline_utils.py +105 -33
  310. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
  311. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
  312. diffusers/pipelines/sana/__init__.py +4 -0
  313. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  314. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  315. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  316. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  317. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  318. diffusers/pipelines/shap_e/camera.py +1 -1
  319. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  320. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  321. diffusers/pipelines/shap_e/renderer.py +3 -3
  322. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  323. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  324. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  325. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  326. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  327. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  328. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  329. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  330. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  331. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  332. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  333. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
  334. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  335. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
  336. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
  337. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
  338. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  339. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  340. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  341. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  342. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  343. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  344. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  345. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  346. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  347. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  348. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  349. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  350. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
  351. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  352. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  353. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  354. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  355. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  356. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  357. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  358. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  359. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  360. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  361. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  362. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  363. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  364. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  365. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  366. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  367. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  368. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  369. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  370. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  371. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  372. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  373. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  374. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  375. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  376. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  377. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  378. diffusers/pipelines/unclip/text_proj.py +2 -2
  379. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  380. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  381. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  382. diffusers/pipelines/visualcloze/__init__.py +52 -0
  383. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  384. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  385. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  386. diffusers/pipelines/wan/__init__.py +2 -0
  387. diffusers/pipelines/wan/pipeline_wan.py +13 -10
  388. diffusers/pipelines/wan/pipeline_wan_i2v.py +38 -18
  389. diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
  390. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  391. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  392. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  393. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  394. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  395. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  396. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  397. diffusers/quantizers/__init__.py +179 -1
  398. diffusers/quantizers/base.py +6 -1
  399. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  400. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  401. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  402. diffusers/quantizers/gguf/utils.py +16 -13
  403. diffusers/quantizers/quantization_config.py +18 -16
  404. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  405. diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
  406. diffusers/schedulers/__init__.py +3 -1
  407. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  408. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  409. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  410. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  411. diffusers/schedulers/scheduling_ddim.py +8 -8
  412. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  413. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  414. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  415. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  416. diffusers/schedulers/scheduling_ddpm.py +9 -9
  417. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  418. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  419. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  420. diffusers/schedulers/scheduling_deis_multistep.py +8 -8
  421. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  422. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
  423. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  424. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  425. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  426. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
  427. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  428. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  429. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  430. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  431. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  432. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  433. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  434. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  435. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  436. diffusers/schedulers/scheduling_ipndm.py +2 -2
  437. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  438. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  439. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  440. diffusers/schedulers/scheduling_lcm.py +3 -3
  441. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  442. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  443. diffusers/schedulers/scheduling_pndm.py +4 -4
  444. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  445. diffusers/schedulers/scheduling_repaint.py +9 -9
  446. diffusers/schedulers/scheduling_sasolver.py +15 -15
  447. diffusers/schedulers/scheduling_scm.py +1 -1
  448. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  449. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  450. diffusers/schedulers/scheduling_tcd.py +3 -3
  451. diffusers/schedulers/scheduling_unclip.py +5 -5
  452. diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
  453. diffusers/schedulers/scheduling_utils.py +1 -1
  454. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  455. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  456. diffusers/training_utils.py +13 -5
  457. diffusers/utils/__init__.py +5 -0
  458. diffusers/utils/accelerate_utils.py +1 -1
  459. diffusers/utils/doc_utils.py +1 -1
  460. diffusers/utils/dummy_pt_objects.py +120 -0
  461. diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
  462. diffusers/utils/dynamic_modules_utils.py +21 -3
  463. diffusers/utils/export_utils.py +1 -1
  464. diffusers/utils/import_utils.py +81 -18
  465. diffusers/utils/logging.py +1 -1
  466. diffusers/utils/outputs.py +2 -1
  467. diffusers/utils/peft_utils.py +91 -8
  468. diffusers/utils/state_dict_utils.py +20 -3
  469. diffusers/utils/testing_utils.py +59 -7
  470. diffusers/utils/torch_utils.py +25 -5
  471. diffusers/video_processor.py +2 -2
  472. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/METADATA +70 -55
  473. diffusers-0.34.0.dist-info/RECORD +639 -0
  474. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/WHEEL +1 -1
  475. diffusers-0.33.1.dist-info/RECORD +0 -608
  476. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
  477. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
  478. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -37,12 +37,16 @@ from .lora_base import ( # noqa
  LoraBaseMixin,
  _fetch_state_dict,
  _load_lora_into_text_encoder,
+ _pack_dict_with_prefix,
  )
  from .lora_conversion_utils import (
  _convert_bfl_flux_control_lora_to_diffusers,
  _convert_hunyuan_video_lora_to_diffusers,
  _convert_kohya_flux_lora_to_diffusers,
+ _convert_musubi_wan_lora_to_diffusers,
+ _convert_non_diffusers_hidream_lora_to_diffusers,
  _convert_non_diffusers_lora_to_diffusers,
+ _convert_non_diffusers_ltxv_lora_to_diffusers,
  _convert_non_diffusers_lumina2_lora_to_diffusers,
  _convert_non_diffusers_wan_lora_to_diffusers,
  _convert_xlabs_flux_lora_to_diffusers,
@@ -78,30 +82,36 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
  from ..quantizers.gguf.utils import dequantize_gguf_tensor

  is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+ is_bnb_8bit_quantized = module.weight.__class__.__name__ == "Int8Params"
  is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"

  if is_bnb_4bit_quantized and not is_bitsandbytes_available():
  raise ValueError(
  "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
  )
+ if is_bnb_8bit_quantized and not is_bitsandbytes_available():
+ raise ValueError(
+ "The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints."
+ )
  if is_gguf_quantized and not is_gguf_available():
  raise ValueError(
  "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
  )

  weight_on_cpu = False
- if not module.weight.is_cuda:
+ if module.weight.device.type == "cpu":
  weight_on_cpu = True

- if is_bnb_4bit_quantized:
+ device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+ if is_bnb_4bit_quantized or is_bnb_8bit_quantized:
  module_weight = dequantize_bnb_weight(
- module.weight.cuda() if weight_on_cpu else module.weight,
- state=module.weight.quant_state,
+ module.weight.to(device) if weight_on_cpu else module.weight,
+ state=module.weight.quant_state if is_bnb_4bit_quantized else module.state,
  dtype=model.dtype,
  ).data
  elif is_gguf_quantized:
  module_weight = dequantize_gguf_tensor(
- module.weight.cuda() if weight_on_cpu else module.weight,
+ module.weight.to(device) if weight_on_cpu else module.weight,
  )
  module_weight = module_weight.to(model.dtype)
  else:
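The hunk above replaces the hard-coded `.cuda()` calls with an accelerator-agnostic device lookup guarded by `hasattr(torch, "accelerator")`. A minimal sketch of that selection logic; the helper name is hypothetical and not part of the diff:

```py
import torch

def _preferred_offload_device() -> str:
    # torch.accelerator is the device-agnostic accelerator API in newer PyTorch;
    # older versions fall back to "cuda", mirroring the guard used in the diff.
    if hasattr(torch, "accelerator"):
        accelerator = torch.accelerator.current_accelerator()
        if accelerator is not None:
            return accelerator.type
    return "cuda"

print(_preferred_offload_device())  # e.g. "cuda", "xpu", or "mps"
```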
@@ -126,7 +136,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  def load_lora_weights(
  self,
  pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- adapter_name=None,
+ adapter_name: Optional[str] = None,
  hotswap: bool = False,
  **kwargs,
  ):
@@ -153,7 +163,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
  weights.
- hotswap : (`bool`, *optional*)
+ hotswap (`bool`, *optional*):
  Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
  in-place. This means that, instead of loading an additional adapter, this will take the existing
  adapter weights and replace them with the weights of the new adapter. This can be faster and more
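For reference, the hotswap workflow this docstring describes looks roughly as follows. The model and LoRA repo ids are placeholders, and the flow follows the example carried by the older docstring that later hunks remove:

```py
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")  # placeholder model

# If the adapters differ in rank/alpha, reserve the highest rank *before*
# compiling and before the first load (see enable_lora_hotswap in the removed docstring).
pipeline.enable_lora_hotswap(target_rank=64)

pipeline.load_lora_weights("user/first-lora", adapter_name="default_0")  # placeholder LoRA
# ...optionally torch.compile the UNet here...

# Replace the already-loaded adapter in place instead of adding a second one.
pipeline.load_lora_weights("user/second-lora", adapter_name="default_0", hotswap=True)
```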
@@ -193,7 +203,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

  # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ kwargs["return_lora_metadata"] = True
+ state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

  is_correct_format = all("lora" in key for key in state_dict.keys())
  if not is_correct_format:
@@ -204,6 +215,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  network_alphas=network_alphas,
  unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
  adapter_name=adapter_name,
+ metadata=metadata,
  _pipeline=self,
  low_cpu_mem_usage=low_cpu_mem_usage,
  hotswap=hotswap,
@@ -217,6 +229,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  lora_scale=self.lora_scale,
  adapter_name=adapter_name,
  _pipeline=self,
+ metadata=metadata,
  low_cpu_mem_usage=low_cpu_mem_usage,
  hotswap=hotswap,
  )
@@ -273,6 +286,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  The subfolder location of a model file within a larger model repository on the Hub or locally.
  weight_name (`str`, *optional*, defaults to None):
  Name of the serialized state dict file.
+ return_lora_metadata (`bool`, *optional*, defaults to False):
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
  """
  # Load the main state dict first which has the LoRA layers for either of
  # UNet and text encoder or both.
@@ -286,18 +301,16 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  weight_name = kwargs.pop("weight_name", None)
  unet_config = kwargs.pop("unet_config", None)
  use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)

  allow_pickle = False
  if use_safetensors is None:
  use_safetensors = True
  allow_pickle = True

- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

- state_dict = _fetch_state_dict(
+ state_dict, metadata = _fetch_state_dict(
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
  weight_name=weight_name,
  use_safetensors=use_safetensors,
@@ -334,7 +347,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
  state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

- return state_dict, network_alphas
+ out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas)
+ return out

  @classmethod
  def load_lora_into_unet(
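The new `return_lora_metadata` flag keeps the old two-tuple return by default and adds the adapter metadata when requested. A minimal sketch of how a caller might use it; the LoRA repo id is a placeholder:

```py
from diffusers import StableDiffusionPipeline

# Default behavior is unchanged: a (state_dict, network_alphas) pair.
state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict("user/some-lora")

# Opting in additionally returns the LoRA adapter metadata, if the file carries any.
state_dict, network_alphas, metadata = StableDiffusionPipeline.lora_state_dict(
    "user/some-lora", return_lora_metadata=True
)
```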
@@ -346,6 +360,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  _pipeline=None,
  low_cpu_mem_usage=False,
  hotswap: bool = False,
+ metadata=None,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -367,29 +382,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading only loading the pretrained LoRA weights and not initializing the random
  weights.
- hotswap : (`bool`, *optional*)
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
- in-place. This means that, instead of loading an additional adapter, this will take the existing
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
- torch.compile, loading the new adapter does not require recompilation of the model. When using
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
- to call an additional method before loading the adapter:
-
- ```py
- pipeline = ... # load diffusers pipeline
- max_rank = ... # the highest rank among all LoRAs that you want to load
- # call *before* compiling and loading the LoRA adapter
- pipeline.enable_lora_hotswap(target_rank=max_rank)
- pipeline.load_lora_weights(file_name)
- # optionally compile the model now
- ```
-
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
- limitations to this technique, which are documented here:
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ metadata (`dict`):
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+ from the state dict.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")
@@ -408,6 +405,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  prefix=cls.unet_name,
  network_alphas=network_alphas,
  adapter_name=adapter_name,
+ metadata=metadata,
  _pipeline=_pipeline,
  low_cpu_mem_usage=low_cpu_mem_usage,
  hotswap=hotswap,
@@ -425,6 +423,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  _pipeline=None,
  low_cpu_mem_usage=False,
  hotswap: bool = False,
+ metadata=None,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -450,29 +449,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
  weights.
- hotswap : (`bool`, *optional*)
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
- in-place. This means that, instead of loading an additional adapter, this will take the existing
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
- torch.compile, loading the new adapter does not require recompilation of the model. When using
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
- to call an additional method before loading the adapter:
-
- ```py
- pipeline = ... # load diffusers pipeline
- max_rank = ... # the highest rank among all LoRAs that you want to load
- # call *before* compiling and loading the LoRA adapter
- pipeline.enable_lora_hotswap(target_rank=max_rank)
- pipeline.load_lora_weights(file_name)
- # optionally compile the model now
- ```
-
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
- limitations to this technique, which are documented here:
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ metadata (`dict`):
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+ from the state dict.
  """
  _load_lora_into_text_encoder(
  state_dict=state_dict,
@@ -482,6 +463,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  prefix=prefix,
  text_encoder_name=cls.text_encoder_name,
  adapter_name=adapter_name,
+ metadata=metadata,
  _pipeline=_pipeline,
  low_cpu_mem_usage=low_cpu_mem_usage,
  hotswap=hotswap,
@@ -497,6 +479,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  weight_name: str = None,
  save_function: Callable = None,
  safe_serialization: bool = True,
+ unet_lora_adapter_metadata=None,
+ text_encoder_lora_adapter_metadata=None,
  ):
  r"""
  Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -519,8 +503,13 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  `DIFFUSERS_SAVE_MODE`.
  safe_serialization (`bool`, *optional*, defaults to `True`):
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ unet_lora_adapter_metadata:
+ LoRA adapter metadata associated with the unet to be serialized with the state dict.
+ text_encoder_lora_adapter_metadata:
+ LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
  """
  state_dict = {}
+ lora_adapter_metadata = {}

  if not (unet_lora_layers or text_encoder_lora_layers):
  raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
@@ -531,6 +520,14 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  if text_encoder_lora_layers:
  state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

+ if unet_lora_adapter_metadata:
+ lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
+
+ if text_encoder_lora_adapter_metadata:
+ lora_adapter_metadata.update(
+ _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
+ )
+
  # Save the model
  cls.write_lora_layers(
  state_dict=state_dict,
@@ -539,6 +536,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
  weight_name=weight_name,
  save_function=save_function,
  safe_serialization=safe_serialization,
+ lora_adapter_metadata=lora_adapter_metadata,
  )

  def fuse_lora(
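With the new `*_lora_adapter_metadata` arguments, adapter configuration can be serialized next to the weights. A minimal sketch of a call, using toy tensors and an output path that are purely illustrative:

```py
import torch
from diffusers import StableDiffusionPipeline

# Toy LoRA tensors for illustration only; real layers come from a trained adapter.
unet_lora_layers = {
    "down_blocks.0.attentions.0.lora_A.weight": torch.zeros(4, 320),
    "down_blocks.0.attentions.0.lora_B.weight": torch.zeros(320, 4),
}

StableDiffusionPipeline.save_lora_weights(
    save_directory="./my-lora",  # hypothetical output directory
    unet_lora_layers=unet_lora_layers,
    unet_lora_adapter_metadata={"r": 4, "lora_alpha": 4},  # illustrative LoraConfig-style entries
)
```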
@@ -624,6 +622,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  self,
  pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
  adapter_name: Optional[str] = None,
+ hotswap: bool = False,
  **kwargs,
  ):
  """
@@ -650,6 +649,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
  weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
  kwargs (`dict`, *optional*):
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
  """
@@ -671,7 +672,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

  # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict, network_alphas = self.lora_state_dict(
+ kwargs["return_lora_metadata"] = True
+ state_dict, network_alphas, metadata = self.lora_state_dict(
  pretrained_model_name_or_path_or_dict,
  unet_config=self.unet.config,
  **kwargs,
@@ -686,8 +688,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  network_alphas=network_alphas,
  unet=self.unet,
  adapter_name=adapter_name,
+ metadata=metadata,
  _pipeline=self,
  low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
  )
  self.load_lora_into_text_encoder(
  state_dict,
@@ -696,8 +700,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  prefix=self.text_encoder_name,
  lora_scale=self.lora_scale,
  adapter_name=adapter_name,
+ metadata=metadata,
  _pipeline=self,
  low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
  )
  self.load_lora_into_text_encoder(
  state_dict,
@@ -706,8 +712,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  prefix=f"{self.text_encoder_name}_2",
  lora_scale=self.lora_scale,
  adapter_name=adapter_name,
+ metadata=metadata,
  _pipeline=self,
  low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
  )

  @classmethod
@@ -763,6 +771,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  The subfolder location of a model file within a larger model repository on the Hub or locally.
  weight_name (`str`, *optional*, defaults to None):
  Name of the serialized state dict file.
+ return_lora_metadata (`bool`, *optional*, defaults to False):
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
  """
  # Load the main state dict first which has the LoRA layers for either of
  # UNet and text encoder or both.
@@ -776,18 +786,16 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  weight_name = kwargs.pop("weight_name", None)
  unet_config = kwargs.pop("unet_config", None)
  use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)

  allow_pickle = False
  if use_safetensors is None:
  use_safetensors = True
  allow_pickle = True

- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

- state_dict = _fetch_state_dict(
+ state_dict, metadata = _fetch_state_dict(
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
  weight_name=weight_name,
  use_safetensors=use_safetensors,
@@ -824,7 +832,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
  state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

- return state_dict, network_alphas
+ out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas)
+ return out

  @classmethod
  # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
@@ -837,6 +846,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  _pipeline=None,
  low_cpu_mem_usage=False,
  hotswap: bool = False,
+ metadata=None,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -858,29 +868,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading only loading the pretrained LoRA weights and not initializing the random
  weights.
- hotswap : (`bool`, *optional*)
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
- in-place. This means that, instead of loading an additional adapter, this will take the existing
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
- torch.compile, loading the new adapter does not require recompilation of the model. When using
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
- to call an additional method before loading the adapter:
-
- ```py
- pipeline = ... # load diffusers pipeline
- max_rank = ... # the highest rank among all LoRAs that you want to load
- # call *before* compiling and loading the LoRA adapter
- pipeline.enable_lora_hotswap(target_rank=max_rank)
- pipeline.load_lora_weights(file_name)
- # optionally compile the model now
- ```
-
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
- limitations to this technique, which are documented here:
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ metadata (`dict`):
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+ from the state dict.
  """
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for this method.")
@@ -899,6 +891,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  prefix=cls.unet_name,
  network_alphas=network_alphas,
  adapter_name=adapter_name,
+ metadata=metadata,
  _pipeline=_pipeline,
  low_cpu_mem_usage=low_cpu_mem_usage,
  hotswap=hotswap,
@@ -917,6 +910,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  _pipeline=None,
  low_cpu_mem_usage=False,
  hotswap: bool = False,
+ metadata=None,
  ):
  """
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -942,29 +936,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
  low_cpu_mem_usage (`bool`, *optional*):
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
  weights.
- hotswap : (`bool`, *optional*)
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
- in-place. This means that, instead of loading an additional adapter, this will take the existing
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
- torch.compile, loading the new adapter does not require recompilation of the model. When using
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
- to call an additional method before loading the adapter:
-
- ```py
- pipeline = ... # load diffusers pipeline
- max_rank = ... # the highest rank among all LoRAs that you want to load
- # call *before* compiling and loading the LoRA adapter
- pipeline.enable_lora_hotswap(target_rank=max_rank)
- pipeline.load_lora_weights(file_name)
- # optionally compile the model now
- ```
-
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
- limitations to this technique, which are documented here:
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ metadata (`dict`):
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+ from the state dict.
  """
  _load_lora_into_text_encoder(
  state_dict=state_dict,
@@ -974,6 +950,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
974
950
  prefix=prefix,
975
951
  text_encoder_name=cls.text_encoder_name,
976
952
  adapter_name=adapter_name,
953
+ metadata=metadata,
977
954
  _pipeline=_pipeline,
978
955
  low_cpu_mem_usage=low_cpu_mem_usage,
979
956
  hotswap=hotswap,
@@ -990,6 +967,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
990
967
  weight_name: str = None,
991
968
  save_function: Callable = None,
992
969
  safe_serialization: bool = True,
970
+ unet_lora_adapter_metadata=None,
971
+ text_encoder_lora_adapter_metadata=None,
972
+ text_encoder_2_lora_adapter_metadata=None,
993
973
  ):
994
974
  r"""
995
975
  Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -1015,8 +995,15 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
1015
995
  `DIFFUSERS_SAVE_MODE`.
1016
996
  safe_serialization (`bool`, *optional*, defaults to `True`):
1017
997
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
998
+ unet_lora_adapter_metadata:
999
+ LoRA adapter metadata associated with the unet to be serialized with the state dict.
1000
+ text_encoder_lora_adapter_metadata:
1001
+ LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
1002
+ text_encoder_2_lora_adapter_metadata:
1003
+ LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
1018
1004
  """
1019
1005
  state_dict = {}
1006
+ lora_adapter_metadata = {}
1020
1007
 
1021
1008
  if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
1022
1009
  raise ValueError(
@@ -1032,6 +1019,19 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
1032
1019
  if text_encoder_2_lora_layers:
1033
1020
  state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
1034
1021
 
1022
+ if unet_lora_adapter_metadata is not None:
1023
+ lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
1024
+
1025
+ if text_encoder_lora_adapter_metadata:
1026
+ lora_adapter_metadata.update(
1027
+ _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
1028
+ )
1029
+
1030
+ if text_encoder_2_lora_adapter_metadata:
1031
+ lora_adapter_metadata.update(
1032
+ _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
1033
+ )
1034
+
1035
1035
  cls.write_lora_layers(
1036
1036
  state_dict=state_dict,
1037
1037
  save_directory=save_directory,
@@ -1039,6 +1039,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
1039
1039
  weight_name=weight_name,
1040
1040
  save_function=save_function,
1041
1041
  safe_serialization=safe_serialization,
1042
+ lora_adapter_metadata=lora_adapter_metadata,
1042
1043
  )
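The new `unet_lora_adapter_metadata`, `text_encoder_lora_adapter_metadata`, and `text_encoder_2_lora_adapter_metadata` arguments are packed per component and handed to `write_lora_layers` above. A hedged sketch of the call site: the state dict is a dummy stand-in for what a training loop produces, and the metadata values are illustrative `peft.LoraConfig` kwargs.

```py
import torch
from diffusers import StableDiffusionXLPipeline

# Dummy LoRA state dict standing in for a trained adapter; real keys follow the
# usual "<module>.lora_A.weight" / "<module>.lora_B.weight" pattern.
unet_lora_layers = {
    "dummy_block.lora_A.weight": torch.zeros(16, 1280),
    "dummy_block.lora_B.weight": torch.zeros(1280, 16),
}

StableDiffusionXLPipeline.save_lora_weights(
    save_directory="./sdxl-lora-with-metadata",
    unet_lora_layers=unet_lora_layers,
    # New in this release: stored next to the weights so that loading does not
    # have to re-derive the peft LoraConfig from the state dict (values illustrative).
    unet_lora_adapter_metadata={"r": 16, "lora_alpha": 16},
)
```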
1043
1044
 
1044
1045
  def fuse_lora(
@@ -1172,6 +1173,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1172
1173
  allowed by Git.
1173
1174
  subfolder (`str`, *optional*, defaults to `""`):
1174
1175
  The subfolder location of a model file within a larger model repository on the Hub or locally.
1176
+ return_lora_metadata (`bool`, *optional*, defaults to `False`):
1177
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
1175
1178
 
1176
1179
  """
1177
1180
  # Load the main state dict first which has the LoRA layers for either of
@@ -1185,18 +1188,16 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1185
1188
  subfolder = kwargs.pop("subfolder", None)
1186
1189
  weight_name = kwargs.pop("weight_name", None)
1187
1190
  use_safetensors = kwargs.pop("use_safetensors", None)
1191
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
1188
1192
 
1189
1193
  allow_pickle = False
1190
1194
  if use_safetensors is None:
1191
1195
  use_safetensors = True
1192
1196
  allow_pickle = True
1193
1197
 
1194
- user_agent = {
1195
- "file_type": "attn_procs_weights",
1196
- "framework": "pytorch",
1197
- }
1198
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
1198
1199
 
1199
- state_dict = _fetch_state_dict(
1200
+ state_dict, metadata = _fetch_state_dict(
1200
1201
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
1201
1202
  weight_name=weight_name,
1202
1203
  use_safetensors=use_safetensors,
@@ -1217,7 +1218,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1217
1218
  logger.warning(warn_msg)
1218
1219
  state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
1219
1220
 
1220
- return state_dict
1221
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
1222
+ return out
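`lora_state_dict` now always receives the embedded adapter metadata from `_fetch_state_dict` and exposes it behind the new `return_lora_metadata` flag, keeping the old single-value return by default. A quick sketch of the classmethod; the LoRA repo id is a placeholder.

```py
from diffusers import StableDiffusion3Pipeline  # mixes in SD3LoraLoaderMixin

# With the flag set, the return value is a (state_dict, metadata) tuple;
# without it, the return type is unchanged from earlier releases.
state_dict, metadata = StableDiffusion3Pipeline.lora_state_dict(
    "some-user/sd3-lora",  # placeholder repo id or local path to a LoRA file
    return_lora_metadata=True,
)
print(len(state_dict), metadata)  # metadata is typically None unless the file embeds it
```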
1221
1223
 
1222
1224
  def load_lora_weights(
1223
1225
  self,
@@ -1247,29 +1249,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1247
1249
  low_cpu_mem_usage (`bool`, *optional*):
1248
1250
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1249
1251
  weights.
1250
- hotswap : (`bool`, *optional*)
1251
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
1252
- in-place. This means that, instead of loading an additional adapter, this will take the existing
1253
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
1254
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
1255
- torch.compile, loading the new adapter does not require recompilation of the model. When using
1256
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
1257
-
1258
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
1259
- to call an additional method before loading the adapter:
1260
-
1261
- ```py
1262
- pipeline = ... # load diffusers pipeline
1263
- max_rank = ... # the highest rank among all LoRAs that you want to load
1264
- # call *before* compiling and loading the LoRA adapter
1265
- pipeline.enable_lora_hotswap(target_rank=max_rank)
1266
- pipeline.load_lora_weights(file_name)
1267
- # optionally compile the model now
1268
- ```
1269
-
1270
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
1271
- limitations to this technique, which are documented here:
1272
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
1252
+ hotswap (`bool`, *optional*):
1253
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
1273
1254
  kwargs (`dict`, *optional*):
1274
1255
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
1275
1256
  """
@@ -1287,7 +1268,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1287
1268
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
1288
1269
 
1289
1270
  # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
1290
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
1271
+ kwargs["return_lora_metadata"] = True
1272
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
1291
1273
 
1292
1274
  is_correct_format = all("lora" in key for key in state_dict.keys())
1293
1275
  if not is_correct_format:
@@ -1297,6 +1279,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1297
1279
  state_dict,
1298
1280
  transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
1299
1281
  adapter_name=adapter_name,
1282
+ metadata=metadata,
1300
1283
  _pipeline=self,
1301
1284
  low_cpu_mem_usage=low_cpu_mem_usage,
1302
1285
  hotswap=hotswap,
@@ -1308,6 +1291,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1308
1291
  prefix=self.text_encoder_name,
1309
1292
  lora_scale=self.lora_scale,
1310
1293
  adapter_name=adapter_name,
1294
+ metadata=metadata,
1311
1295
  _pipeline=self,
1312
1296
  low_cpu_mem_usage=low_cpu_mem_usage,
1313
1297
  hotswap=hotswap,
@@ -1319,6 +1303,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1319
1303
  prefix=f"{self.text_encoder_name}_2",
1320
1304
  lora_scale=self.lora_scale,
1321
1305
  adapter_name=adapter_name,
1306
+ metadata=metadata,
1322
1307
  _pipeline=self,
1323
1308
  low_cpu_mem_usage=low_cpu_mem_usage,
1324
1309
  hotswap=hotswap,
@@ -1326,7 +1311,14 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1326
1311
 
1327
1312
  @classmethod
1328
1313
  def load_lora_into_transformer(
1329
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
1314
+ cls,
1315
+ state_dict,
1316
+ transformer,
1317
+ adapter_name=None,
1318
+ _pipeline=None,
1319
+ low_cpu_mem_usage=False,
1320
+ hotswap: bool = False,
1321
+ metadata=None,
1330
1322
  ):
1331
1323
  """
1332
1324
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1344,29 +1336,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1344
1336
  low_cpu_mem_usage (`bool`, *optional*):
1345
1337
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1346
1338
  weights.
1347
- hotswap : (`bool`, *optional*)
1348
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
1349
- in-place. This means that, instead of loading an additional adapter, this will take the existing
1350
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
1351
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
1352
- torch.compile, loading the new adapter does not require recompilation of the model. When using
1353
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
1354
-
1355
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
1356
- to call an additional method before loading the adapter:
1357
-
1358
- ```py
1359
- pipeline = ... # load diffusers pipeline
1360
- max_rank = ... # the highest rank among all LoRAs that you want to load
1361
- # call *before* compiling and loading the LoRA adapter
1362
- pipeline.enable_lora_hotswap(target_rank=max_rank)
1363
- pipeline.load_lora_weights(file_name)
1364
- # optionally compile the model now
1365
- ```
1366
-
1367
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
1368
- limitations to this technique, which are documented here:
1369
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
1339
+ hotswap (`bool`, *optional*):
1340
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
1341
+ metadata (`dict`):
1342
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
1343
+ from the state dict.
1370
1344
  """
1371
1345
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
1372
1346
  raise ValueError(
@@ -1379,6 +1353,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1379
1353
  state_dict,
1380
1354
  network_alphas=None,
1381
1355
  adapter_name=adapter_name,
1356
+ metadata=metadata,
1382
1357
  _pipeline=_pipeline,
1383
1358
  low_cpu_mem_usage=low_cpu_mem_usage,
1384
1359
  hotswap=hotswap,
@@ -1397,6 +1372,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1397
1372
  _pipeline=None,
1398
1373
  low_cpu_mem_usage=False,
1399
1374
  hotswap: bool = False,
1375
+ metadata=None,
1400
1376
  ):
1401
1377
  """
1402
1378
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1422,29 +1398,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1422
1398
  low_cpu_mem_usage (`bool`, *optional*):
1423
1399
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1424
1400
  weights.
1425
- hotswap : (`bool`, *optional*)
1426
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
1427
- in-place. This means that, instead of loading an additional adapter, this will take the existing
1428
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
1429
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
1430
- torch.compile, loading the new adapter does not require recompilation of the model. When using
1431
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
1432
-
1433
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
1434
- to call an additional method before loading the adapter:
1435
-
1436
- ```py
1437
- pipeline = ... # load diffusers pipeline
1438
- max_rank = ... # the highest rank among all LoRAs that you want to load
1439
- # call *before* compiling and loading the LoRA adapter
1440
- pipeline.enable_lora_hotswap(target_rank=max_rank)
1441
- pipeline.load_lora_weights(file_name)
1442
- # optionally compile the model now
1443
- ```
1444
-
1445
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
1446
- limitations to this technique, which are documented here:
1447
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
1401
+ hotswap (`bool`, *optional*):
1402
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
1403
+ metadata (`dict`):
1404
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
1405
+ from the state dict.
1448
1406
  """
1449
1407
  _load_lora_into_text_encoder(
1450
1408
  state_dict=state_dict,
@@ -1454,6 +1412,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1454
1412
  prefix=prefix,
1455
1413
  text_encoder_name=cls.text_encoder_name,
1456
1414
  adapter_name=adapter_name,
1415
+ metadata=metadata,
1457
1416
  _pipeline=_pipeline,
1458
1417
  low_cpu_mem_usage=low_cpu_mem_usage,
1459
1418
  hotswap=hotswap,
@@ -1471,6 +1430,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1471
1430
  weight_name: str = None,
1472
1431
  save_function: Callable = None,
1473
1432
  safe_serialization: bool = True,
1433
+ transformer_lora_adapter_metadata=None,
1434
+ text_encoder_lora_adapter_metadata=None,
1435
+ text_encoder_2_lora_adapter_metadata=None,
1474
1436
  ):
1475
1437
  r"""
1476
1438
  Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -1496,8 +1458,15 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1496
1458
  `DIFFUSERS_SAVE_MODE`.
1497
1459
  safe_serialization (`bool`, *optional*, defaults to `True`):
1498
1460
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
1461
+ transformer_lora_adapter_metadata:
1462
+ LoRA adapter metadata associated with the transformer to be serialized with the state dict.
1463
+ text_encoder_lora_adapter_metadata:
1464
+ LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
1465
+ text_encoder_2_lora_adapter_metadata:
1466
+ LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
1499
1467
  """
1500
1468
  state_dict = {}
1469
+ lora_adapter_metadata = {}
1501
1470
 
1502
1471
  if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
1503
1472
  raise ValueError(
@@ -1513,6 +1482,21 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1513
1482
  if text_encoder_2_lora_layers:
1514
1483
  state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
1515
1484
 
1485
+ if transformer_lora_adapter_metadata is not None:
1486
+ lora_adapter_metadata.update(
1487
+ _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
1488
+ )
1489
+
1490
+ if text_encoder_lora_adapter_metadata:
1491
+ lora_adapter_metadata.update(
1492
+ _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
1493
+ )
1494
+
1495
+ if text_encoder_2_lora_adapter_metadata:
1496
+ lora_adapter_metadata.update(
1497
+ _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
1498
+ )
1499
+
1516
1500
  cls.write_lora_layers(
1517
1501
  state_dict=state_dict,
1518
1502
  save_directory=save_directory,
@@ -1520,6 +1504,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1520
1504
  weight_name=weight_name,
1521
1505
  save_function=save_function,
1522
1506
  safe_serialization=safe_serialization,
1507
+ lora_adapter_metadata=lora_adapter_metadata,
1523
1508
  )
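`_pack_dict_with_prefix` is a private helper; assuming from its name and its usage here that it namespaces keys as `"<prefix>.<key>"` (the same convention `pack_weights` applies to tensors), the serialized metadata for the SD3 components would look like the output below. The re-implementation is illustrative only, not the library helper itself.

```py
def pack_dict_with_prefix(d, prefix):
    # Assumed behavior: prefix every key with "<prefix>."
    return {f"{prefix}.{k}": v for k, v in d.items()}

lora_adapter_metadata = {}
lora_adapter_metadata.update(pack_dict_with_prefix({"r": 16, "lora_alpha": 16}, "transformer"))
lora_adapter_metadata.update(pack_dict_with_prefix({"r": 8, "lora_alpha": 8}, "text_encoder"))
print(lora_adapter_metadata)
# {'transformer.r': 16, 'transformer.lora_alpha': 16, 'text_encoder.r': 8, 'text_encoder.lora_alpha': 8}
```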
1524
1509
 
1525
1510
  # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
@@ -1592,25 +1577,20 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
1592
1577
  super().unfuse_lora(components=components, **kwargs)
1593
1578
 
1594
1579
 
1595
- class FluxLoraLoaderMixin(LoraBaseMixin):
1580
+ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
1596
1581
  r"""
1597
- Load LoRA layers into [`FluxTransformer2DModel`],
1598
- [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
1599
-
1600
- Specific to [`StableDiffusion3Pipeline`].
1582
+ Load LoRA layers into [`AuraFlowTransformer2DModel`]. Specific to [`AuraFlowPipeline`].
1601
1583
  """
1602
1584
 
1603
- _lora_loadable_modules = ["transformer", "text_encoder"]
1585
+ _lora_loadable_modules = ["transformer"]
1604
1586
  transformer_name = TRANSFORMER_NAME
1605
- text_encoder_name = TEXT_ENCODER_NAME
1606
- _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
1607
1587
 
1608
1588
  @classmethod
1609
1589
  @validate_hf_hub_args
1590
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
1610
1591
  def lora_state_dict(
1611
1592
  cls,
1612
1593
  pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
1613
- return_alphas: bool = False,
1614
1594
  **kwargs,
1615
1595
  ):
1616
1596
  r"""
@@ -1656,6 +1636,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1656
1636
  allowed by Git.
1657
1637
  subfolder (`str`, *optional*, defaults to `""`):
1658
1638
  The subfolder location of a model file within a larger model repository on the Hub or locally.
1639
+ return_lora_metadata (`bool`, *optional*, defaults to False):
1640
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
1659
1641
 
1660
1642
  """
1661
1643
  # Load the main state dict first which has the LoRA layers for either of
@@ -1669,18 +1651,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1669
1651
  subfolder = kwargs.pop("subfolder", None)
1670
1652
  weight_name = kwargs.pop("weight_name", None)
1671
1653
  use_safetensors = kwargs.pop("use_safetensors", None)
1654
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
1672
1655
 
1673
1656
  allow_pickle = False
1674
1657
  if use_safetensors is None:
1675
1658
  use_safetensors = True
1676
1659
  allow_pickle = True
1677
1660
 
1678
- user_agent = {
1679
- "file_type": "attn_procs_weights",
1680
- "framework": "pytorch",
1681
- }
1661
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
1682
1662
 
1683
- state_dict = _fetch_state_dict(
1663
+ state_dict, metadata = _fetch_state_dict(
1684
1664
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
1685
1665
  weight_name=weight_name,
1686
1666
  use_safetensors=use_safetensors,
@@ -1694,101 +1674,453 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1694
1674
  user_agent=user_agent,
1695
1675
  allow_pickle=allow_pickle,
1696
1676
  )
1677
+
1697
1678
  is_dora_scale_present = any("dora_scale" in k for k in state_dict)
1698
1679
  if is_dora_scale_present:
1699
1680
  warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
1700
1681
  logger.warning(warn_msg)
1701
1682
  state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
1702
1683
 
1703
- # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
1704
- is_kohya = any(".lora_down.weight" in k for k in state_dict)
1705
- if is_kohya:
1706
- state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
1707
- # Kohya already takes care of scaling the LoRA parameters with alpha.
1708
- return (state_dict, None) if return_alphas else state_dict
1709
-
1710
- is_xlabs = any("processor" in k for k in state_dict)
1711
- if is_xlabs:
1712
- state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
1713
- # xlabs doesn't use `alpha`.
1714
- return (state_dict, None) if return_alphas else state_dict
1715
-
1716
- is_bfl_control = any("query_norm.scale" in k for k in state_dict)
1717
- if is_bfl_control:
1718
- state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
1719
- return (state_dict, None) if return_alphas else state_dict
1720
-
1721
- # For state dicts like
1722
- # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
1723
- keys = list(state_dict.keys())
1724
- network_alphas = {}
1725
- for k in keys:
1726
- if "alpha" in k:
1727
- alpha_value = state_dict.get(k)
1728
- if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
1729
- alpha_value, float
1730
- ):
1731
- network_alphas[k] = state_dict.pop(k)
1732
- else:
1733
- raise ValueError(
1734
- f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
1735
- )
1736
-
1737
- if return_alphas:
1738
- return state_dict, network_alphas
1739
- else:
1740
- return state_dict
1684
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
1685
+ return out
1741
1686
 
1687
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
1742
1688
  def load_lora_weights(
1743
1689
  self,
1744
1690
  pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
1745
- adapter_name=None,
1691
+ adapter_name: Optional[str] = None,
1746
1692
  hotswap: bool = False,
1747
1693
  **kwargs,
1748
1694
  ):
1749
1695
  """
1750
1696
  Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
1751
- `self.text_encoder`.
1752
-
1753
- All kwargs are forwarded to `self.lora_state_dict`.
1754
-
1755
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
1756
- loaded.
1757
-
1697
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
1698
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
1758
1699
  See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
1759
1700
  dict is loaded into `self.transformer`.
1760
1701
 
1761
1702
  Parameters:
1762
1703
  pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
1763
1704
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
1764
- kwargs (`dict`, *optional*):
1765
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
1766
1705
  adapter_name (`str`, *optional*):
1767
1706
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1768
1707
  `default_{i}` where i is the total number of adapters being loaded.
1769
1708
  low_cpu_mem_usage (`bool`, *optional*):
1770
- `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1709
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1771
1710
  weights.
1772
- hotswap : (`bool`, *optional*)
1773
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
1774
- in-place. This means that, instead of loading an additional adapter, this will take the existing
1775
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
1776
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
1777
- torch.compile, loading the new adapter does not require recompilation of the model. When using
1778
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new
1779
- adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an
1780
- additional method before loading the adapter:
1781
- ```py
1782
- pipeline = ... # load diffusers pipeline
1783
- max_rank = ... # the highest rank among all LoRAs that you want to load
1784
- # call *before* compiling and loading the LoRA adapter
1785
- pipeline.enable_lora_hotswap(target_rank=max_rank)
1786
- pipeline.load_lora_weights(file_name)
1787
- # optionally compile the model now
1788
- ```
1789
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
1790
- limitations to this technique, which are documented here:
1791
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
1711
+ hotswap (`bool`, *optional*):
1712
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
1713
+ kwargs (`dict`, *optional*):
1714
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
1715
+ """
1716
+ if not USE_PEFT_BACKEND:
1717
+ raise ValueError("PEFT backend is required for this method.")
1718
+
1719
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
1720
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
1721
+ raise ValueError(
1722
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
1723
+ )
1724
+
1725
+ # if a dict is passed, copy it instead of modifying it inplace
1726
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
1727
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
1728
+
1729
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
1730
+ kwargs["return_lora_metadata"] = True
1731
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
1732
+
1733
+ is_correct_format = all("lora" in key for key in state_dict.keys())
1734
+ if not is_correct_format:
1735
+ raise ValueError("Invalid LoRA checkpoint.")
1736
+
1737
+ self.load_lora_into_transformer(
1738
+ state_dict,
1739
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
1740
+ adapter_name=adapter_name,
1741
+ metadata=metadata,
1742
+ _pipeline=self,
1743
+ low_cpu_mem_usage=low_cpu_mem_usage,
1744
+ hotswap=hotswap,
1745
+ )
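The new `AuraFlowLoraLoaderMixin` gives [`AuraFlowPipeline`] the same transformer-only LoRA entry points as the other single-transformer pipelines ("Copied from" the CogVideoX loader). A hedged usage sketch: the "fal/AuraFlow" checkpoint id and the LoRA repo/adapter names are assumptions not stated in this diff.

```py
import torch
from diffusers import AuraFlowPipeline  # AuraFlowPipeline picks up this mixin

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# Only the transformer is LoRA-loadable here (_lora_loadable_modules = ["transformer"]).
pipe.load_lora_weights("some-user/auraflow-lora", adapter_name="style")  # placeholder repo
image = pipe("a watercolor fox").images[0]
```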
1746
+
1747
+ @classmethod
1748
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel
1749
+ def load_lora_into_transformer(
1750
+ cls,
1751
+ state_dict,
1752
+ transformer,
1753
+ adapter_name=None,
1754
+ _pipeline=None,
1755
+ low_cpu_mem_usage=False,
1756
+ hotswap: bool = False,
1757
+ metadata=None,
1758
+ ):
1759
+ """
1760
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
1761
+
1762
+ Parameters:
1763
+ state_dict (`dict`):
1764
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
1765
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
1766
+ encoder lora layers.
1767
+ transformer (`AuraFlowTransformer2DModel`):
1768
+ The Transformer model to load the LoRA layers into.
1769
+ adapter_name (`str`, *optional*):
1770
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1771
+ `default_{i}` where i is the total number of adapters being loaded.
1772
+ low_cpu_mem_usage (`bool`, *optional*):
1773
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1774
+ weights.
1775
+ hotswap (`bool`, *optional*):
1776
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
1777
+ metadata (`dict`):
1778
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
1779
+ from the state dict.
1780
+ """
1781
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
1782
+ raise ValueError(
1783
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
1784
+ )
1785
+
1786
+ # Load the layers corresponding to transformer.
1787
+ logger.info(f"Loading {cls.transformer_name}.")
1788
+ transformer.load_lora_adapter(
1789
+ state_dict,
1790
+ network_alphas=None,
1791
+ adapter_name=adapter_name,
1792
+ metadata=metadata,
1793
+ _pipeline=_pipeline,
1794
+ low_cpu_mem_usage=low_cpu_mem_usage,
1795
+ hotswap=hotswap,
1796
+ )
1797
+
1798
+ @classmethod
1799
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
1800
+ def save_lora_weights(
1801
+ cls,
1802
+ save_directory: Union[str, os.PathLike],
1803
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1804
+ is_main_process: bool = True,
1805
+ weight_name: str = None,
1806
+ save_function: Callable = None,
1807
+ safe_serialization: bool = True,
1808
+ transformer_lora_adapter_metadata: Optional[dict] = None,
1809
+ ):
1810
+ r"""
1811
+ Save the LoRA parameters corresponding to the transformer.
1812
+
1813
+ Arguments:
1814
+ save_directory (`str` or `os.PathLike`):
1815
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
1816
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1817
+ State dict of the LoRA layers corresponding to the `transformer`.
1818
+ is_main_process (`bool`, *optional*, defaults to `True`):
1819
+ Whether the process calling this is the main process or not. Useful during distributed training and you
1820
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
1821
+ process to avoid race conditions.
1822
+ save_function (`Callable`):
1823
+ The function to use to save the state dictionary. Useful during distributed training when you need to
1824
+ replace `torch.save` with another method. Can be configured with the environment variable
1825
+ `DIFFUSERS_SAVE_MODE`.
1826
+ safe_serialization (`bool`, *optional*, defaults to `True`):
1827
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
1828
+ transformer_lora_adapter_metadata:
1829
+ LoRA adapter metadata associated with the transformer to be serialized with the state dict.
1830
+ """
1831
+ state_dict = {}
1832
+ lora_adapter_metadata = {}
1833
+
1834
+ if not transformer_lora_layers:
1835
+ raise ValueError("You must pass `transformer_lora_layers`.")
1836
+
1837
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
1838
+
1839
+ if transformer_lora_adapter_metadata is not None:
1840
+ lora_adapter_metadata.update(
1841
+ _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
1842
+ )
1843
+
1844
+ # Save the model
1845
+ cls.write_lora_layers(
1846
+ state_dict=state_dict,
1847
+ save_directory=save_directory,
1848
+ is_main_process=is_main_process,
1849
+ weight_name=weight_name,
1850
+ save_function=save_function,
1851
+ safe_serialization=safe_serialization,
1852
+ lora_adapter_metadata=lora_adapter_metadata,
1853
+ )
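Because `save_lora_weights` forwards `lora_adapter_metadata` to `write_lora_layers` and `lora_state_dict` reads it back via `return_lora_metadata`, the metadata round-trips with the weights. A hedged sketch with a dummy state dict; the default weight file name and the exact key layout of the returned metadata are assumptions.

```py
import torch
from diffusers.loaders.lora_pipeline import AuraFlowLoraLoaderMixin

# Dummy transformer LoRA tensors standing in for a trained peft adapter.
layers = {
    "blocks.0.attn.to_q.lora_A.weight": torch.zeros(8, 3072),
    "blocks.0.attn.to_q.lora_B.weight": torch.zeros(3072, 8),
}

AuraFlowLoraLoaderMixin.save_lora_weights(
    save_directory="./auraflow-lora",
    transformer_lora_layers=layers,
    transformer_lora_adapter_metadata={"r": 8, "lora_alpha": 8},  # illustrative values
)

state_dict, metadata = AuraFlowLoraLoaderMixin.lora_state_dict(
    "./auraflow-lora",
    weight_name="pytorch_lora_weights.safetensors",  # assumed default file name
    return_lora_metadata=True,
)
print(sorted(state_dict), metadata)
```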
1854
+
1855
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
1856
+ def fuse_lora(
1857
+ self,
1858
+ components: List[str] = ["transformer"],
1859
+ lora_scale: float = 1.0,
1860
+ safe_fusing: bool = False,
1861
+ adapter_names: Optional[List[str]] = None,
1862
+ **kwargs,
1863
+ ):
1864
+ r"""
1865
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
1866
+
1867
+ <Tip warning={true}>
1868
+
1869
+ This is an experimental API.
1870
+
1871
+ </Tip>
1872
+
1873
+ Args:
1874
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
1875
+ lora_scale (`float`, defaults to 1.0):
1876
+ Controls how much to influence the outputs with the LoRA parameters.
1877
+ safe_fusing (`bool`, defaults to `False`):
1878
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
1879
+ adapter_names (`List[str]`, *optional*):
1880
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
1881
+
1882
+ Example:
1883
+
1884
+ ```py
1885
+ from diffusers import DiffusionPipeline
1886
+ import torch
1887
+
1888
+ pipeline = DiffusionPipeline.from_pretrained(
1889
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
1890
+ ).to("cuda")
1891
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
1892
+ pipeline.fuse_lora(lora_scale=0.7)
1893
+ ```
1894
+ """
1895
+ super().fuse_lora(
1896
+ components=components,
1897
+ lora_scale=lora_scale,
1898
+ safe_fusing=safe_fusing,
1899
+ adapter_names=adapter_names,
1900
+ **kwargs,
1901
+ )
1902
+
1903
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
1904
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
1905
+ r"""
1906
+ Reverses the effect of
1907
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
1908
+
1909
+ <Tip warning={true}>
1910
+
1911
+ This is an experimental API.
1912
+
1913
+ </Tip>
1914
+
1915
+ Args:
1916
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
1917
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
1918
+ """
1919
+ super().unfuse_lora(components=components, **kwargs)
1920
+
1921
+
1922
+ class FluxLoraLoaderMixin(LoraBaseMixin):
1923
+ r"""
1924
+ Load LoRA layers into [`FluxTransformer2DModel`],
1925
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
1926
+
1927
+ Specific to [`FluxPipeline`].
1928
+ """
1929
+
1930
+ _lora_loadable_modules = ["transformer", "text_encoder"]
1931
+ transformer_name = TRANSFORMER_NAME
1932
+ text_encoder_name = TEXT_ENCODER_NAME
1933
+ _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
1934
+
1935
+ @classmethod
1936
+ @validate_hf_hub_args
1937
+ def lora_state_dict(
1938
+ cls,
1939
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
1940
+ return_alphas: bool = False,
1941
+ **kwargs,
1942
+ ):
1943
+ r"""
1944
+ Return state dict for lora weights and the network alphas.
1945
+
1946
+ <Tip warning={true}>
1947
+
1948
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
1949
+
1950
+ This function is experimental and might change in the future.
1951
+
1952
+ </Tip>
1953
+
1954
+ Parameters:
1955
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
1956
+ Can be either:
1957
+
1958
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1959
+ the Hub.
1960
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1961
+ with [`ModelMixin.save_pretrained`].
1962
+ - A [torch state
1963
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
1964
+
1965
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1966
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1967
+ is not used.
1968
+ force_download (`bool`, *optional*, defaults to `False`):
1969
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1970
+ cached versions if they exist.
1971
+
1972
+ proxies (`Dict[str, str]`, *optional*):
1973
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1974
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1975
+ local_files_only (`bool`, *optional*, defaults to `False`):
1976
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1977
+ won't be downloaded from the Hub.
1978
+ token (`str` or *bool*, *optional*):
1979
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1980
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1981
+ revision (`str`, *optional*, defaults to `"main"`):
1982
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1983
+ allowed by Git.
1984
+ subfolder (`str`, *optional*, defaults to `""`):
1985
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
1986
+ return_lora_metadata (`bool`, *optional*, defaults to `False`):
1987
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
1988
+ """
1989
+ # Load the main state dict first which has the LoRA layers for either of
1990
+ # transformer and text encoder or both.
1991
+ cache_dir = kwargs.pop("cache_dir", None)
1992
+ force_download = kwargs.pop("force_download", False)
1993
+ proxies = kwargs.pop("proxies", None)
1994
+ local_files_only = kwargs.pop("local_files_only", None)
1995
+ token = kwargs.pop("token", None)
1996
+ revision = kwargs.pop("revision", None)
1997
+ subfolder = kwargs.pop("subfolder", None)
1998
+ weight_name = kwargs.pop("weight_name", None)
1999
+ use_safetensors = kwargs.pop("use_safetensors", None)
2000
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
2001
+
2002
+ allow_pickle = False
2003
+ if use_safetensors is None:
2004
+ use_safetensors = True
2005
+ allow_pickle = True
2006
+
2007
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
2008
+
2009
+ state_dict, metadata = _fetch_state_dict(
2010
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
2011
+ weight_name=weight_name,
2012
+ use_safetensors=use_safetensors,
2013
+ local_files_only=local_files_only,
2014
+ cache_dir=cache_dir,
2015
+ force_download=force_download,
2016
+ proxies=proxies,
2017
+ token=token,
2018
+ revision=revision,
2019
+ subfolder=subfolder,
2020
+ user_agent=user_agent,
2021
+ allow_pickle=allow_pickle,
2022
+ )
2023
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
2024
+ if is_dora_scale_present:
2025
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
2026
+ logger.warning(warn_msg)
2027
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
2028
+
2029
+ # TODO (sayakpaul): do a follow-up to clean and try to unify the conditions.
2030
+ is_kohya = any(".lora_down.weight" in k for k in state_dict)
2031
+ if is_kohya:
2032
+ state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
2033
+ # Kohya already takes care of scaling the LoRA parameters with alpha.
2034
+ return cls._prepare_outputs(
2035
+ state_dict,
2036
+ metadata=metadata,
2037
+ alphas=None,
2038
+ return_alphas=return_alphas,
2039
+ return_metadata=return_lora_metadata,
2040
+ )
2041
+
2042
+ is_xlabs = any("processor" in k for k in state_dict)
2043
+ if is_xlabs:
2044
+ state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
2045
+ # xlabs doesn't use `alpha`.
2046
+ return cls._prepare_outputs(
2047
+ state_dict,
2048
+ metadata=metadata,
2049
+ alphas=None,
2050
+ return_alphas=return_alphas,
2051
+ return_metadata=return_lora_metadata,
2052
+ )
2053
+
2054
+ is_bfl_control = any("query_norm.scale" in k for k in state_dict)
2055
+ if is_bfl_control:
2056
+ state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
2057
+ return cls._prepare_outputs(
2058
+ state_dict,
2059
+ metadata=metadata,
2060
+ alphas=None,
2061
+ return_alphas=return_alphas,
2062
+ return_metadata=return_lora_metadata,
2063
+ )
2064
+
2065
+ # For state dicts like
2066
+ # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
2067
+ keys = list(state_dict.keys())
2068
+ network_alphas = {}
2069
+ for k in keys:
2070
+ if "alpha" in k:
2071
+ alpha_value = state_dict.get(k)
2072
+ if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
2073
+ alpha_value, float
2074
+ ):
2075
+ network_alphas[k] = state_dict.pop(k)
2076
+ else:
2077
+ raise ValueError(
2078
+ f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
2079
+ )
2080
+
2081
+ if return_alphas or return_lora_metadata:
2082
+ return cls._prepare_outputs(
2083
+ state_dict,
2084
+ metadata=metadata,
2085
+ alphas=network_alphas,
2086
+ return_alphas=return_alphas,
2087
+ return_metadata=return_lora_metadata,
2088
+ )
2089
+ else:
2090
+ return state_dict
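For Flux, `lora_state_dict` keeps the existing `return_alphas` flag and layers the new `return_lora_metadata` flag on top, with `_prepare_outputs` (added further below) deciding the tuple shape. With both flags set, the result unpacks in the order used by `load_lora_weights` in a later hunk: state dict, network alphas, metadata. The repo id below is the one already referenced in the comment in this function; other details of the checkpoint are not guaranteed here.

```py
from diffusers import FluxPipeline  # FluxPipeline mixes in FluxLoraLoaderMixin

# With both flags set the classmethod returns (state_dict, network_alphas, metadata);
# for kohya/xlabs/BFL-control checkpoints the alphas slot is None.
state_dict, network_alphas, metadata = FluxPipeline.lora_state_dict(
    "TheLastBen/Jon_Snow_Flux_LoRA",
    return_alphas=True,
    return_lora_metadata=True,
)
print(len(state_dict), network_alphas, metadata)
```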
2091
+
2092
+ def load_lora_weights(
2093
+ self,
2094
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
2095
+ adapter_name: Optional[str] = None,
2096
+ hotswap: bool = False,
2097
+ **kwargs,
2098
+ ):
2099
+ """
2100
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
2101
+ `self.text_encoder`.
2102
+
2103
+ All kwargs are forwarded to `self.lora_state_dict`.
2104
+
2105
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
2106
+ loaded.
2107
+
2108
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
2109
+ dict is loaded into `self.transformer`.
2110
+
2111
+ Parameters:
2112
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
2113
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
2114
+ adapter_name (`str`, *optional*):
2115
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
2116
+ `default_{i}` where i is the total number of adapters being loaded.
2117
+ low_cpu_mem_usage (`bool`, *optional*):
2118
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2119
+ weights.
2120
+ hotswap (`bool`, *optional*):
2121
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
2122
+ kwargs (`dict`, *optional*):
2123
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
1792
2124
  """
1793
2125
  if not USE_PEFT_BACKEND:
1794
2126
  raise ValueError("PEFT backend is required for this method.")
@@ -1804,7 +2136,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1804
2136
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
1805
2137
 
1806
2138
  # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
1807
- state_dict, network_alphas = self.lora_state_dict(
2139
+ kwargs["return_lora_metadata"] = True
2140
+ state_dict, network_alphas, metadata = self.lora_state_dict(
1808
2141
  pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
1809
2142
  )
1810
2143
 
@@ -1855,6 +2188,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1855
2188
  network_alphas=network_alphas,
1856
2189
  transformer=transformer,
1857
2190
  adapter_name=adapter_name,
2191
+ metadata=metadata,
1858
2192
  _pipeline=self,
1859
2193
  low_cpu_mem_usage=low_cpu_mem_usage,
1860
2194
  hotswap=hotswap,
@@ -1874,6 +2208,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1874
2208
  prefix=self.text_encoder_name,
1875
2209
  lora_scale=self.lora_scale,
1876
2210
  adapter_name=adapter_name,
2211
+ metadata=metadata,
1877
2212
  _pipeline=self,
1878
2213
  low_cpu_mem_usage=low_cpu_mem_usage,
1879
2214
  hotswap=hotswap,
@@ -1886,6 +2221,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1886
2221
  network_alphas,
1887
2222
  transformer,
1888
2223
  adapter_name=None,
2224
+ metadata=None,
1889
2225
  _pipeline=None,
1890
2226
  low_cpu_mem_usage=False,
1891
2227
  hotswap: bool = False,
@@ -1910,29 +2246,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1910
2246
  low_cpu_mem_usage (`bool`, *optional*):
1911
2247
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
1912
2248
  weights.
1913
- hotswap : (`bool`, *optional*)
1914
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
1915
- in-place. This means that, instead of loading an additional adapter, this will take the existing
1916
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
1917
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
1918
- torch.compile, loading the new adapter does not require recompilation of the model. When using
1919
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
1920
-
1921
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
1922
- to call an additional method before loading the adapter:
1923
-
1924
- ```py
1925
- pipeline = ... # load diffusers pipeline
1926
- max_rank = ... # the highest rank among all LoRAs that you want to load
1927
- # call *before* compiling and loading the LoRA adapter
1928
- pipeline.enable_lora_hotswap(target_rank=max_rank)
1929
- pipeline.load_lora_weights(file_name)
1930
- # optionally compile the model now
1931
- ```
1932
-
1933
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
1934
- limitations to this technique, which are documented here:
1935
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
2249
+ hotswap (`bool`, *optional*):
2250
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
2251
+ metadata (`dict`):
2252
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
2253
+ from the state dict.
1936
2254
  """
1937
2255
  if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
1938
2256
  raise ValueError(
@@ -1945,6 +2263,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1945
2263
  state_dict,
1946
2264
  network_alphas=network_alphas,
1947
2265
  adapter_name=adapter_name,
2266
+ metadata=metadata,
1948
2267
  _pipeline=_pipeline,
1949
2268
  low_cpu_mem_usage=low_cpu_mem_usage,
1950
2269
  hotswap=hotswap,
@@ -1962,7 +2281,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
1962
2281
  prefix = prefix or cls.transformer_name
1963
2282
  for key in list(state_dict.keys()):
1964
2283
  if key.split(".")[0] == prefix:
1965
- state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
2284
+ state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
1966
2285
 
1967
2286
  # Find invalid keys
1968
2287
  transformer_state_dict = transformer.state_dict()
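The prefix-stripping in this hunk (and in the norm-expansion helper further down) switches from a fixed-width slice to `str.removeprefix`, which was added in Python 3.9 and only strips when the prefix is actually present:

```py
prefix = "transformer"
key = "transformer.single_transformer_blocks.0.proj_out.lora_A.weight"

old = key[len(f"{prefix}.") :]          # slices unconditionally
new = key.removeprefix(f"{prefix}.")    # no-op if the prefix is absent
assert old == new == "single_transformer_blocks.0.proj_out.lora_A.weight"
assert "other.weight".removeprefix(f"{prefix}.") == "other.weight"
```

In the surrounding loop the key is already known to start with the prefix, so the two spellings behave identically; the change is for readability.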
@@ -2017,6 +2336,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2017
2336
  _pipeline=None,
2018
2337
  low_cpu_mem_usage=False,
2019
2338
  hotswap: bool = False,
2339
+ metadata=None,
2020
2340
  ):
2021
2341
  """
2022
2342
  This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -2040,31 +2360,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2040
2360
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
2041
2361
  `default_{i}` where i is the total number of adapters being loaded.
2042
2362
  low_cpu_mem_usage (`bool`, *optional*):
2043
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2044
- weights.
2045
- hotswap : (`bool`, *optional*)
2046
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
2047
- in-place. This means that, instead of loading an additional adapter, this will take the existing
2048
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
2049
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
2050
- torch.compile, loading the new adapter does not require recompilation of the model. When using
2051
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
2052
-
2053
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
2054
- to call an additional method before loading the adapter:
2055
-
2056
- ```py
2057
- pipeline = ... # load diffusers pipeline
2058
- max_rank = ... # the highest rank among all LoRAs that you want to load
2059
- # call *before* compiling and loading the LoRA adapter
2060
- pipeline.enable_lora_hotswap(target_rank=max_rank)
2061
- pipeline.load_lora_weights(file_name)
2062
- # optionally compile the model now
2063
- ```
2064
-
2065
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
2066
- limitations to this technique, which are documented here:
2067
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
2363
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2364
+ weights.
2365
+ hotswap (`bool`, *optional*):
2366
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
2367
+ metadata (`dict`):
2368
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
2369
+ from the state dict.
2068
2370
  """
2069
2371
  _load_lora_into_text_encoder(
2070
2372
  state_dict=state_dict,
@@ -2074,6 +2376,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2074
2376
  prefix=prefix,
2075
2377
  text_encoder_name=cls.text_encoder_name,
2076
2378
  adapter_name=adapter_name,
2379
+ metadata=metadata,
2077
2380
  _pipeline=_pipeline,
2078
2381
  low_cpu_mem_usage=low_cpu_mem_usage,
2079
2382
  hotswap=hotswap,
@@ -2090,6 +2393,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2090
2393
  weight_name: str = None,
2091
2394
  save_function: Callable = None,
2092
2395
  safe_serialization: bool = True,
2396
+ transformer_lora_adapter_metadata=None,
2397
+ text_encoder_lora_adapter_metadata=None,
2093
2398
  ):
2094
2399
  r"""
2095
2400
  Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -2112,8 +2417,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2112
2417
  `DIFFUSERS_SAVE_MODE`.
2113
2418
  safe_serialization (`bool`, *optional*, defaults to `True`):
2114
2419
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
2420
+ transformer_lora_adapter_metadata:
2421
+ LoRA adapter metadata associated with the transformer to be serialized with the state dict.
2422
+ text_encoder_lora_adapter_metadata:
2423
+ LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
2115
2424
  """
2116
2425
  state_dict = {}
2426
+ lora_adapter_metadata = {}
2117
2427
 
2118
2428
  if not (transformer_lora_layers or text_encoder_lora_layers):
2119
2429
  raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
@@ -2124,6 +2434,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2124
2434
  if text_encoder_lora_layers:
2125
2435
  state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
2126
2436
 
2437
+ if transformer_lora_adapter_metadata:
2438
+ lora_adapter_metadata.update(
2439
+ _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
2440
+ )
2441
+
2442
+ if text_encoder_lora_adapter_metadata:
2443
+ lora_adapter_metadata.update(
2444
+ _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
2445
+ )
2446
+
2127
2447
  # Save the model
2128
2448
  cls.write_lora_layers(
2129
2449
  state_dict=state_dict,
@@ -2132,6 +2452,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2132
2452
  weight_name=weight_name,
2133
2453
  save_function=save_function,
2134
2454
  safe_serialization=safe_serialization,
2455
+ lora_adapter_metadata=lora_adapter_metadata,
2135
2456
  )
2136
2457
 
2137
2458
  def fuse_lora(
@@ -2293,7 +2614,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2293
2614
  ) -> bool:
2294
2615
  """
2295
2616
  Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
2296
- generalizes things a bit so that any parameter that needs expansion receives appropriate treatement.
2617
+ generalizes things a bit so that any parameter that needs expansion receives appropriate treatment.
2297
2618
  """
2298
2619
  state_dict = {}
2299
2620
  if lora_state_dict is not None:
@@ -2305,7 +2626,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2305
2626
  prefix = prefix or cls.transformer_name
2306
2627
  for key in list(state_dict.keys()):
2307
2628
  if key.split(".")[0] == prefix:
2308
- state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
2629
+ state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
2309
2630
 
2310
2631
  # Expand transformer parameter shapes if they don't match lora
2311
2632
  has_param_with_shape_update = False
@@ -2423,14 +2744,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2423
2744
  if unexpected_modules:
2424
2745
  logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
2425
2746
 
2426
- is_peft_loaded = getattr(transformer, "peft_config", None) is not None
2427
2747
  for k in lora_module_names:
2428
2748
  if k in unexpected_modules:
2429
2749
  continue
2430
2750
 
2431
2751
  base_param_name = (
2432
2752
  f"{k.replace(prefix, '')}.base_layer.weight"
2433
- if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
2753
+ if f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
2434
2754
  else f"{k.replace(prefix, '')}.weight"
2435
2755
  )
2436
2756
  base_weight_param = transformer_state_dict[base_param_name]
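This hunk drops the `is_peft_loaded` guard: instead of first checking `transformer.peft_config`, the expansion logic now keys directly off whether the peft-wrapped `.base_layer.weight` entry exists in the transformer state dict. A tiny self-contained illustration of the lookup rule (the module name is hypothetical):

```py
# When peft has injected a LoRA adapter into a Linear, the pristine weight lives at
# "<module>.base_layer.weight"; before injection it is plain "<module>.weight".
transformer_state_dict = {"some_block.proj_out.base_layer.weight": "..."}  # hypothetical
name = "some_block.proj_out"

base_param_name = (
    f"{name}.base_layer.weight"
    if f"{name}.base_layer.weight" in transformer_state_dict
    else f"{name}.weight"
)
print(base_param_name)  # some_block.proj_out.base_layer.weight
```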
@@ -2484,6 +2804,15 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
2484
2804
 
2485
2805
  raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
2486
2806
 
2807
+ @staticmethod
2808
+ def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
2809
+ outputs = [state_dict]
2810
+ if return_alphas:
2811
+ outputs.append(alphas)
2812
+ if return_metadata:
2813
+ outputs.append(metadata)
2814
+ return tuple(outputs) if (return_alphas or return_metadata) else state_dict
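The new `_prepare_outputs` helper centralizes the return-shape logic: with no flags it keeps the old bare-state-dict return, and each enabled flag appends one more element, alphas before metadata. A direct illustration, importing the mixin from this module:

```py
from diffusers.loaders.lora_pipeline import FluxLoraLoaderMixin

sd, alphas, meta = {"w": 0}, {"alpha": 8.0}, {"r": 8}

FluxLoraLoaderMixin._prepare_outputs(sd, metadata=meta)                        # {'w': 0}
FluxLoraLoaderMixin._prepare_outputs(sd, metadata=meta, return_metadata=True)  # ({'w': 0}, {'r': 8})
FluxLoraLoaderMixin._prepare_outputs(
    sd, metadata=meta, alphas=alphas, return_alphas=True, return_metadata=True
)                                                                              # ({'w': 0}, {'alpha': 8.0}, {'r': 8})
```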
2815
+
2487
2816
 
2488
2817
  # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
2489
2818
  # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
@@ -2500,6 +2829,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2500
2829
  network_alphas,
2501
2830
  transformer,
2502
2831
  adapter_name=None,
2832
+ metadata=None,
2503
2833
  _pipeline=None,
2504
2834
  low_cpu_mem_usage=False,
2505
2835
  hotswap: bool = False,
@@ -2524,143 +2854,380 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2524
2854
  low_cpu_mem_usage (`bool`, *optional*):
2525
2855
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2526
2856
  weights.
2527
- hotswap : (`bool`, *optional*)
2528
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
2529
- in-place. This means that, instead of loading an additional adapter, this will take the existing
2530
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
2531
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
2532
- torch.compile, loading the new adapter does not require recompilation of the model. When using
2533
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
2534
-
2535
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
2536
- to call an additional method before loading the adapter:
2537
-
2538
- ```py
2539
- pipeline = ... # load diffusers pipeline
2540
- max_rank = ... # the highest rank among all LoRAs that you want to load
2541
- # call *before* compiling and loading the LoRA adapter
2542
- pipeline.enable_lora_hotswap(target_rank=max_rank)
2543
- pipeline.load_lora_weights(file_name)
2544
- # optionally compile the model now
2545
- ```
2546
-
2547
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
2548
- limitations to this technique, which are documented here:
2549
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
2857
+ hotswap (`bool`, *optional*):
2858
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
2859
+ metadata (`dict`):
2860
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
2861
+ from the state dict.
2862
+ """
2863
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
2864
+ raise ValueError(
2865
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
2866
+ )
2867
+
2868
+ # Load the layers corresponding to transformer.
2869
+ logger.info(f"Loading {cls.transformer_name}.")
2870
+ transformer.load_lora_adapter(
2871
+ state_dict,
2872
+ network_alphas=network_alphas,
2873
+ adapter_name=adapter_name,
2874
+ metadata=metadata,
2875
+ _pipeline=_pipeline,
2876
+ low_cpu_mem_usage=low_cpu_mem_usage,
2877
+ hotswap=hotswap,
2878
+ )
2879
+
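The detailed hotswap walkthrough is condensed above into a cross-reference. For context, a hedged sketch of the workflow the removed text described; the model id and LoRA file names are placeholders:

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("some/model-id", torch_dtype=torch.float16)  # placeholder id
max_rank = 64  # the highest rank among all LoRAs that will be hotswapped
pipeline.enable_lora_hotswap(target_rank=max_rank)  # call *before* compiling and loading the LoRA
pipeline.load_lora_weights("lora_a.safetensors", adapter_name="default_0")
# optionally compile the denoiser now; hotswapping then replaces the adapter weights in place,
# so loading the next LoRA does not trigger recompilation:
pipeline.load_lora_weights("lora_b.safetensors", adapter_name="default_0", hotswap=True)
```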
2880
+ @classmethod
2881
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
2882
+ def load_lora_into_text_encoder(
2883
+ cls,
2884
+ state_dict,
2885
+ network_alphas,
2886
+ text_encoder,
2887
+ prefix=None,
2888
+ lora_scale=1.0,
2889
+ adapter_name=None,
2890
+ _pipeline=None,
2891
+ low_cpu_mem_usage=False,
2892
+ hotswap: bool = False,
2893
+ metadata=None,
2894
+ ):
2895
+ """
2896
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
2897
+
2898
+ Parameters:
2899
+ state_dict (`dict`):
2900
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
2901
+ additional `text_encoder` to distinguish between unet lora layers.
2902
+ network_alphas (`Dict[str, float]`):
2903
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
2904
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
2905
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
2906
+ text_encoder (`CLIPTextModel`):
2907
+ The text encoder model to load the LoRA layers into.
2908
+ prefix (`str`):
2909
+ Expected prefix of the `text_encoder` in the `state_dict`.
2910
+ lora_scale (`float`):
2911
+ How much to scale the output of the lora linear layer before it is added with the output of the regular
2912
+ lora layer.
2913
+ adapter_name (`str`, *optional*):
2914
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
2915
+ `default_{i}` where i is the total number of adapters being loaded.
2916
+ low_cpu_mem_usage (`bool`, *optional*):
2917
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2918
+ weights.
2919
+ hotswap (`bool`, *optional*):
2920
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
2921
+ metadata (`dict`):
2922
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
2923
+ from the state dict.
2924
+ """
2925
+ _load_lora_into_text_encoder(
2926
+ state_dict=state_dict,
2927
+ network_alphas=network_alphas,
2928
+ lora_scale=lora_scale,
2929
+ text_encoder=text_encoder,
2930
+ prefix=prefix,
2931
+ text_encoder_name=cls.text_encoder_name,
2932
+ adapter_name=adapter_name,
2933
+ metadata=metadata,
2934
+ _pipeline=_pipeline,
2935
+ low_cpu_mem_usage=low_cpu_mem_usage,
2936
+ hotswap=hotswap,
2937
+ )
2938
+
2939
+ @classmethod
2940
+ def save_lora_weights(
2941
+ cls,
2942
+ save_directory: Union[str, os.PathLike],
2943
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
2944
+ transformer_lora_layers: Dict[str, torch.nn.Module] = None,
2945
+ is_main_process: bool = True,
2946
+ weight_name: str = None,
2947
+ save_function: Callable = None,
2948
+ safe_serialization: bool = True,
2949
+ ):
2950
+ r"""
2951
+ Save the LoRA parameters corresponding to the UNet and text encoder.
2952
+
2953
+ Arguments:
2954
+ save_directory (`str` or `os.PathLike`):
2955
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
2956
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
2957
+ State dict of the LoRA layers corresponding to the `transformer`.

2958
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
2959
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
2960
+ encoder LoRA state dict because it comes from 🤗 Transformers.
2961
+ is_main_process (`bool`, *optional*, defaults to `True`):
2962
+ Whether the process calling this is the main process or not. Useful during distributed training and you
2963
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
2964
+ process to avoid race conditions.
2965
+ save_function (`Callable`):
2966
+ The function to use to save the state dictionary. Useful during distributed training when you need to
2967
+ replace `torch.save` with another method. Can be configured with the environment variable
2968
+ `DIFFUSERS_SAVE_MODE`.
2969
+ safe_serialization (`bool`, *optional*, defaults to `True`):
2970
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
2971
+ """
2972
+ state_dict = {}
2973
+
2974
+ if not (transformer_lora_layers or text_encoder_lora_layers):
2975
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
2976
+
2977
+ if transformer_lora_layers:
2978
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
2979
+
2980
+ if text_encoder_lora_layers:
2981
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
2982
+
2983
+ # Save the model
2984
+ cls.write_lora_layers(
2985
+ state_dict=state_dict,
2986
+ save_directory=save_directory,
2987
+ is_main_process=is_main_process,
2988
+ weight_name=weight_name,
2989
+ save_function=save_function,
2990
+ safe_serialization=safe_serialization,
2991
+ )
2992
+
2993
+
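For reference, a hedged standalone sketch of the key namespacing that `pack_weights` applies before the combined state dict is written; the implementation shown is illustrative, not the library's:

```py
# Illustrative re-implementation: every LoRA key is prefixed with its component name so the
# transformer and text encoder entries can coexist in one serialized state dict.
def pack_weights(layers, prefix):
    weights = layers.state_dict() if hasattr(layers, "state_dict") else layers
    return {f"{prefix}.{name}": param for name, param in weights.items()}

packed = pack_weights({"blocks.0.attn.to_q.lora_A.weight": 0}, "transformer")
print(list(packed))  # ['transformer.blocks.0.attn.to_q.lora_A.weight']
```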
2994
+ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2995
+ r"""
2996
+ Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
2997
+ """
2998
+
2999
+ _lora_loadable_modules = ["transformer"]
3000
+ transformer_name = TRANSFORMER_NAME
3001
+
3002
+ @classmethod
3003
+ @validate_hf_hub_args
3004
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
3005
+ def lora_state_dict(
3006
+ cls,
3007
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
3008
+ **kwargs,
3009
+ ):
3010
+ r"""
3011
+ Return state dict for lora weights and the network alphas.
3012
+
3013
+ <Tip warning={true}>
3014
+
3015
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
3016
+
3017
+ This function is experimental and might change in the future.
3018
+
3019
+ </Tip>
3020
+
3021
+ Parameters:
3022
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3023
+ Can be either:
3024
+
3025
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
3026
+ the Hub.
3027
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
3028
+ with [`ModelMixin.save_pretrained`].
3029
+ - A [torch state
3030
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
3031
+
3032
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
3033
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
3034
+ is not used.
3035
+ force_download (`bool`, *optional*, defaults to `False`):
3036
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
3037
+ cached versions if they exist.
3038
+
3039
+ proxies (`Dict[str, str]`, *optional*):
3040
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
3041
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
3042
+ local_files_only (`bool`, *optional*, defaults to `False`):
3043
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
3044
+ won't be downloaded from the Hub.
3045
+ token (`str` or *bool*, *optional*):
3046
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
3047
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
3048
+ revision (`str`, *optional*, defaults to `"main"`):
3049
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
3050
+ allowed by Git.
3051
+ subfolder (`str`, *optional*, defaults to `""`):
3052
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
3053
+ return_lora_metadata (`bool`, *optional*, defaults to False):
3054
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
3055
+
3056
+ """
3057
+ # Load the main state dict first which has the LoRA layers for either of
3058
+ # transformer and text encoder or both.
3059
+ cache_dir = kwargs.pop("cache_dir", None)
3060
+ force_download = kwargs.pop("force_download", False)
3061
+ proxies = kwargs.pop("proxies", None)
3062
+ local_files_only = kwargs.pop("local_files_only", None)
3063
+ token = kwargs.pop("token", None)
3064
+ revision = kwargs.pop("revision", None)
3065
+ subfolder = kwargs.pop("subfolder", None)
3066
+ weight_name = kwargs.pop("weight_name", None)
3067
+ use_safetensors = kwargs.pop("use_safetensors", None)
3068
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
3069
+
3070
+ allow_pickle = False
3071
+ if use_safetensors is None:
3072
+ use_safetensors = True
3073
+ allow_pickle = True
3074
+
3075
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
3076
+
3077
+ state_dict, metadata = _fetch_state_dict(
3078
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
3079
+ weight_name=weight_name,
3080
+ use_safetensors=use_safetensors,
3081
+ local_files_only=local_files_only,
3082
+ cache_dir=cache_dir,
3083
+ force_download=force_download,
3084
+ proxies=proxies,
3085
+ token=token,
3086
+ revision=revision,
3087
+ subfolder=subfolder,
3088
+ user_agent=user_agent,
3089
+ allow_pickle=allow_pickle,
3090
+ )
3091
+
3092
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
3093
+ if is_dora_scale_present:
3094
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
3095
+ logger.warning(warn_msg)
3096
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
3097
+
3098
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
3099
+ return out
3100
+
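A hedged usage sketch of the new `return_lora_metadata` flag; the repository id is a placeholder:

```py
from diffusers import CogVideoXPipeline

# With the flag set, the classmethod returns (state_dict, metadata) instead of just the state
# dict; metadata may be None when the checkpoint carries no adapter metadata.
state_dict, metadata = CogVideoXPipeline.lora_state_dict(
    "some-user/cogvideox-lora",  # placeholder repo id
    weight_name="pytorch_lora_weights.safetensors",
    return_lora_metadata=True,
)
print(len(state_dict), metadata)
```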
3101
+ def load_lora_weights(
3102
+ self,
3103
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
3104
+ adapter_name: Optional[str] = None,
3105
+ hotswap: bool = False,
3106
+ **kwargs,
3107
+ ):
3108
+ """
3109
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
3110
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
3111
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
3112
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
3113
+ dict is loaded into `self.transformer`.
3114
+
3115
+ Parameters:
3116
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3117
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3118
+ adapter_name (`str`, *optional*):
3119
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3120
+ `default_{i}` where i is the total number of adapters being loaded.
3121
+ low_cpu_mem_usage (`bool`, *optional*):
3122
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3123
+ weights.
3124
+ hotswap (`bool`, *optional*):
3125
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
3126
+ kwargs (`dict`, *optional*):
3127
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
2550
3128
  """
2551
- if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
3129
+ if not USE_PEFT_BACKEND:
3130
+ raise ValueError("PEFT backend is required for this method.")
3131
+
3132
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
3133
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
2552
3134
  raise ValueError(
2553
3135
  "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
2554
3136
  )
2555
3137
 
2556
- # Load the layers corresponding to transformer.
2557
- logger.info(f"Loading {cls.transformer_name}.")
2558
- transformer.load_lora_adapter(
3138
+ # if a dict is passed, copy it instead of modifying it inplace
3139
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
3140
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
3141
+
3142
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
3143
+ kwargs["return_lora_metadata"] = True
3144
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3145
+
3146
+ is_correct_format = all("lora" in key for key in state_dict.keys())
3147
+ if not is_correct_format:
3148
+ raise ValueError("Invalid LoRA checkpoint.")
3149
+
3150
+ self.load_lora_into_transformer(
2559
3151
  state_dict,
2560
- network_alphas=network_alphas,
3152
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
2561
3153
  adapter_name=adapter_name,
2562
- _pipeline=_pipeline,
3154
+ metadata=metadata,
3155
+ _pipeline=self,
2563
3156
  low_cpu_mem_usage=low_cpu_mem_usage,
2564
3157
  hotswap=hotswap,
2565
3158
  )
2566
3159
 
2567
3160
  @classmethod
2568
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
2569
- def load_lora_into_text_encoder(
3161
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
3162
+ def load_lora_into_transformer(
2570
3163
  cls,
2571
3164
  state_dict,
2572
- network_alphas,
2573
- text_encoder,
2574
- prefix=None,
2575
- lora_scale=1.0,
3165
+ transformer,
2576
3166
  adapter_name=None,
2577
3167
  _pipeline=None,
2578
3168
  low_cpu_mem_usage=False,
2579
3169
  hotswap: bool = False,
3170
+ metadata=None,
2580
3171
  ):
2581
3172
  """
2582
- This will load the LoRA layers specified in `state_dict` into `text_encoder`
3173
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
2583
3174
 
2584
3175
  Parameters:
2585
3176
  state_dict (`dict`):
2586
- A standard state dict containing the lora layer parameters. The key should be prefixed with an
2587
- additional `text_encoder` to distinguish between unet lora layers.
2588
- network_alphas (`Dict[str, float]`):
2589
- The value of the network alpha used for stable learning and preventing underflow. This value has the
2590
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
2591
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
2592
- text_encoder (`CLIPTextModel`):
2593
- The text encoder model to load the LoRA layers into.
2594
- prefix (`str`):
2595
- Expected prefix of the `text_encoder` in the `state_dict`.
2596
- lora_scale (`float`):
2597
- How much to scale the output of the lora linear layer before it is added with the output of the regular
2598
- lora layer.
3177
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
3178
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
3179
+ encoder lora layers.
3180
+ transformer (`CogVideoXTransformer3DModel`):
3181
+ The Transformer model to load the LoRA layers into.
2599
3182
  adapter_name (`str`, *optional*):
2600
3183
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
2601
3184
  `default_{i}` where i is the total number of adapters being loaded.
2602
3185
  low_cpu_mem_usage (`bool`, *optional*):
2603
3186
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2604
3187
  weights.
2605
- hotswap : (`bool`, *optional*)
2606
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
2607
- in-place. This means that, instead of loading an additional adapter, this will take the existing
2608
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
2609
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
2610
- torch.compile, loading the new adapter does not require recompilation of the model. When using
2611
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
2612
-
2613
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
2614
- to call an additional method before loading the adapter:
2615
-
2616
- ```py
2617
- pipeline = ... # load diffusers pipeline
2618
- max_rank = ... # the highest rank among all LoRAs that you want to load
2619
- # call *before* compiling and loading the LoRA adapter
2620
- pipeline.enable_lora_hotswap(target_rank=max_rank)
2621
- pipeline.load_lora_weights(file_name)
2622
- # optionally compile the model now
2623
- ```
2624
-
2625
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
2626
- limitations to this technique, which are documented here:
2627
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
3188
+ hotswap (`bool`, *optional*):
3189
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
3190
+ metadata (`dict`):
3191
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
3192
+ from the state dict.
2628
3193
  """
2629
- _load_lora_into_text_encoder(
2630
- state_dict=state_dict,
2631
- network_alphas=network_alphas,
2632
- lora_scale=lora_scale,
2633
- text_encoder=text_encoder,
2634
- prefix=prefix,
2635
- text_encoder_name=cls.text_encoder_name,
3194
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3195
+ raise ValueError(
3196
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
3197
+ )
3198
+
3199
+ # Load the layers corresponding to transformer.
3200
+ logger.info(f"Loading {cls.transformer_name}.")
3201
+ transformer.load_lora_adapter(
3202
+ state_dict,
3203
+ network_alphas=None,
2636
3204
  adapter_name=adapter_name,
3205
+ metadata=metadata,
2637
3206
  _pipeline=_pipeline,
2638
3207
  low_cpu_mem_usage=low_cpu_mem_usage,
2639
3208
  hotswap=hotswap,
2640
3209
  )
2641
3210
 
2642
3211
  @classmethod
3212
+ # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
2643
3213
  def save_lora_weights(
2644
3214
  cls,
2645
3215
  save_directory: Union[str, os.PathLike],
2646
- text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
2647
- transformer_lora_layers: Dict[str, torch.nn.Module] = None,
3216
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
2648
3217
  is_main_process: bool = True,
2649
3218
  weight_name: str = None,
2650
3219
  save_function: Callable = None,
2651
3220
  safe_serialization: bool = True,
3221
+ transformer_lora_adapter_metadata: Optional[dict] = None,
2652
3222
  ):
2653
3223
  r"""
2654
- Save the LoRA parameters corresponding to the UNet and text encoder.
3224
+ Save the LoRA parameters corresponding to the transformer.
2655
3225
 
2656
3226
  Arguments:
2657
3227
  save_directory (`str` or `os.PathLike`):
2658
3228
  Directory to save LoRA parameters to. Will be created if it doesn't exist.
2659
- unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
2660
- State dict of the LoRA layers corresponding to the `unet`.
2661
- text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
2662
- State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
2663
- encoder LoRA state dict because it comes from 🤗 Transformers.
3229
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
3230
+ State dict of the LoRA layers corresponding to the `transformer`.
2664
3231
  is_main_process (`bool`, *optional*, defaults to `True`):
2665
3232
  Whether the process calling this is the main process or not. Useful during distributed training and you
2666
3233
  need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -2671,17 +3238,21 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2671
3238
  `DIFFUSERS_SAVE_MODE`.
2672
3239
  safe_serialization (`bool`, *optional*, defaults to `True`):
2673
3240
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
3241
+ transformer_lora_adapter_metadata:
3242
+ LoRA adapter metadata associated with the transformer to be serialized with the state dict.
2674
3243
  """
2675
3244
  state_dict = {}
3245
+ lora_adapter_metadata = {}
2676
3246
 
2677
- if not (transformer_lora_layers or text_encoder_lora_layers):
2678
- raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
3247
+ if not transformer_lora_layers:
3248
+ raise ValueError("You must pass `transformer_lora_layers`.")
2679
3249
 
2680
- if transformer_lora_layers:
2681
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
3250
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
2682
3251
 
2683
- if text_encoder_lora_layers:
2684
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
3252
+ if transformer_lora_adapter_metadata is not None:
3253
+ lora_adapter_metadata.update(
3254
+ _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
3255
+ )
2685
3256
 
2686
3257
  # Save the model
2687
3258
  cls.write_lora_layers(
@@ -2691,12 +3262,77 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2691
3262
  weight_name=weight_name,
2692
3263
  save_function=save_function,
2693
3264
  safe_serialization=safe_serialization,
3265
+ lora_adapter_metadata=lora_adapter_metadata,
2694
3266
  )
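A hedged sketch of saving transformer LoRA layers together with the new `transformer_lora_adapter_metadata` argument; the layer dict and metadata values are toy placeholders:

```py
import torch
from diffusers import CogVideoXPipeline

transformer_lora_layers = {"blocks.0.attn1.to_q.lora_A.weight": torch.zeros(4, 16)}  # toy tensor
CogVideoXPipeline.save_lora_weights(
    save_directory="./cogvideox-lora",
    transformer_lora_layers=transformer_lora_layers,
    transformer_lora_adapter_metadata={"r": 4, "lora_alpha": 4},  # assumed LoraConfig-style keys
)
```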
2695
3267
 
3268
+ def fuse_lora(
3269
+ self,
3270
+ components: List[str] = ["transformer"],
3271
+ lora_scale: float = 1.0,
3272
+ safe_fusing: bool = False,
3273
+ adapter_names: Optional[List[str]] = None,
3274
+ **kwargs,
3275
+ ):
3276
+ r"""
3277
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
3278
+
3279
+ <Tip warning={true}>
2696
3280
 
2697
- class CogVideoXLoraLoaderMixin(LoraBaseMixin):
3281
+ This is an experimental API.
3282
+
3283
+ </Tip>
3284
+
3285
+ Args:
3286
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
3287
+ lora_scale (`float`, defaults to 1.0):
3288
+ Controls how much to influence the outputs with the LoRA parameters.
3289
+ safe_fusing (`bool`, defaults to `False`):
3290
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
3291
+ adapter_names (`List[str]`, *optional*):
3292
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
3293
+
3294
+ Example:
3295
+
3296
+ ```py
3297
+ from diffusers import DiffusionPipeline
3298
+ import torch
3299
+
3300
+ pipeline = DiffusionPipeline.from_pretrained(
3301
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
3302
+ ).to("cuda")
3303
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
3304
+ pipeline.fuse_lora(lora_scale=0.7)
3305
+ ```
3306
+ """
3307
+ super().fuse_lora(
3308
+ components=components,
3309
+ lora_scale=lora_scale,
3310
+ safe_fusing=safe_fusing,
3311
+ adapter_names=adapter_names,
3312
+ **kwargs,
3313
+ )
3314
+
3315
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
3316
+ r"""
3317
+ Reverses the effect of
3318
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
3319
+
3320
+ <Tip warning={true}>
3321
+
3322
+ This is an experimental API.
3323
+
3324
+ </Tip>
3325
+
3326
+ Args:
3327
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
3328
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
3329
+ """
3330
+ super().unfuse_lora(components=components, **kwargs)
3331
+
3332
+
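A hedged round-trip sketch for the `fuse_lora`/`unfuse_lora` pair added above; the model and LoRA ids are placeholders:

```py
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.load_lora_weights("some-user/cogvideox-lora", adapter_name="style")  # placeholder repo id
pipe.fuse_lora(components=["transformer"], lora_scale=0.8)  # bake the scaled LoRA into the weights
# ... run inference with the fused weights ...
pipe.unfuse_lora(components=["transformer"])  # restore the original transformer parameters
```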
3333
+ class Mochi1LoraLoaderMixin(LoraBaseMixin):
2698
3334
  r"""
2699
- Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
3335
+ Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
2700
3336
  """
2701
3337
 
2702
3338
  _lora_loadable_modules = ["transformer"]
@@ -2753,6 +3389,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2753
3389
  allowed by Git.
2754
3390
  subfolder (`str`, *optional*, defaults to `""`):
2755
3391
  The subfolder location of a model file within a larger model repository on the Hub or locally.
3392
+ return_lora_metadata (`bool`, *optional*, defaults to False):
3393
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
2756
3394
 
2757
3395
  """
2758
3396
  # Load the main state dict first which has the LoRA layers for either of
@@ -2766,18 +3404,16 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2766
3404
  subfolder = kwargs.pop("subfolder", None)
2767
3405
  weight_name = kwargs.pop("weight_name", None)
2768
3406
  use_safetensors = kwargs.pop("use_safetensors", None)
3407
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
2769
3408
 
2770
3409
  allow_pickle = False
2771
3410
  if use_safetensors is None:
2772
3411
  use_safetensors = True
2773
3412
  allow_pickle = True
2774
3413
 
2775
- user_agent = {
2776
- "file_type": "attn_procs_weights",
2777
- "framework": "pytorch",
2778
- }
3414
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
2779
3415
 
2780
- state_dict = _fetch_state_dict(
3416
+ state_dict, metadata = _fetch_state_dict(
2781
3417
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
2782
3418
  weight_name=weight_name,
2783
3419
  use_safetensors=use_safetensors,
@@ -2798,10 +3434,16 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2798
3434
  logger.warning(warn_msg)
2799
3435
  state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
2800
3436
 
2801
- return state_dict
3437
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
3438
+ return out
2802
3439
 
3440
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
2803
3441
  def load_lora_weights(
2804
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
3442
+ self,
3443
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
3444
+ adapter_name: Optional[str] = None,
3445
+ hotswap: bool = False,
3446
+ **kwargs,
2805
3447
  ):
2806
3448
  """
2807
3449
  Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
@@ -2819,6 +3461,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2819
3461
  low_cpu_mem_usage (`bool`, *optional*):
2820
3462
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2821
3463
  weights.
3464
+ hotswap (`bool`, *optional*):
3465
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
2822
3466
  kwargs (`dict`, *optional*):
2823
3467
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
2824
3468
  """
@@ -2836,7 +3480,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2836
3480
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
2837
3481
 
2838
3482
  # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
2839
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3483
+ kwargs["return_lora_metadata"] = True
3484
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
2840
3485
 
2841
3486
  is_correct_format = all("lora" in key for key in state_dict.keys())
2842
3487
  if not is_correct_format:
@@ -2846,54 +3491,45 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2846
3491
  state_dict,
2847
3492
  transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
2848
3493
  adapter_name=adapter_name,
3494
+ metadata=metadata,
2849
3495
  _pipeline=self,
2850
3496
  low_cpu_mem_usage=low_cpu_mem_usage,
3497
+ hotswap=hotswap,
2851
3498
  )
2852
3499
 
2853
3500
  @classmethod
2854
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
3501
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
2855
3502
  def load_lora_into_transformer(
2856
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
3503
+ cls,
3504
+ state_dict,
3505
+ transformer,
3506
+ adapter_name=None,
3507
+ _pipeline=None,
3508
+ low_cpu_mem_usage=False,
3509
+ hotswap: bool = False,
3510
+ metadata=None,
2857
3511
  ):
2858
3512
  """
2859
3513
  This will load the LoRA layers specified in `state_dict` into `transformer`.
2860
3514
 
2861
- Parameters:
2862
- state_dict (`dict`):
2863
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
2864
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
2865
- encoder lora layers.
2866
- transformer (`CogVideoXTransformer3DModel`):
2867
- The Transformer model to load the LoRA layers into.
2868
- adapter_name (`str`, *optional*):
2869
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
2870
- `default_{i}` where i is the total number of adapters being loaded.
2871
- low_cpu_mem_usage (`bool`, *optional*):
2872
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2873
- weights.
2874
- hotswap : (`bool`, *optional*)
2875
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
2876
- in-place. This means that, instead of loading an additional adapter, this will take the existing
2877
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
2878
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
2879
- torch.compile, loading the new adapter does not require recompilation of the model. When using
2880
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
2881
-
2882
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
2883
- to call an additional method before loading the adapter:
2884
-
2885
- ```py
2886
- pipeline = ... # load diffusers pipeline
2887
- max_rank = ... # the highest rank among all LoRAs that you want to load
2888
- # call *before* compiling and loading the LoRA adapter
2889
- pipeline.enable_lora_hotswap(target_rank=max_rank)
2890
- pipeline.load_lora_weights(file_name)
2891
- # optionally compile the model now
2892
- ```
2893
-
2894
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
2895
- limitations to this technique, which are documented here:
2896
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
3515
+ Parameters:
3516
+ state_dict (`dict`):
3517
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
3518
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
3519
+ encoder lora layers.
3520
+ transformer (`MochiTransformer3DModel`):
3521
+ The Transformer model to load the LoRA layers into.
3522
+ adapter_name (`str`, *optional*):
3523
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3524
+ `default_{i}` where i is the total number of adapters being loaded.
3525
+ low_cpu_mem_usage (`bool`, *optional*):
3526
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3527
+ weights.
3528
+ hotswap (`bool`, *optional*):
3529
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
3530
+ metadata (`dict`):
3531
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
3532
+ from the state dict.
2897
3533
  """
2898
3534
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
2899
3535
  raise ValueError(
@@ -2906,13 +3542,14 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2906
3542
  state_dict,
2907
3543
  network_alphas=None,
2908
3544
  adapter_name=adapter_name,
3545
+ metadata=metadata,
2909
3546
  _pipeline=_pipeline,
2910
3547
  low_cpu_mem_usage=low_cpu_mem_usage,
2911
3548
  hotswap=hotswap,
2912
3549
  )
2913
3550
 
2914
3551
  @classmethod
2915
- # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
3552
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
2916
3553
  def save_lora_weights(
2917
3554
  cls,
2918
3555
  save_directory: Union[str, os.PathLike],
@@ -2921,9 +3558,10 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2921
3558
  weight_name: str = None,
2922
3559
  save_function: Callable = None,
2923
3560
  safe_serialization: bool = True,
3561
+ transformer_lora_adapter_metadata: Optional[dict] = None,
2924
3562
  ):
2925
3563
  r"""
2926
- Save the LoRA parameters corresponding to the UNet and text encoder.
3564
+ Save the LoRA parameters corresponding to the transformer.
2927
3565
 
2928
3566
  Arguments:
2929
3567
  save_directory (`str` or `os.PathLike`):
@@ -2940,14 +3578,21 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2940
3578
  `DIFFUSERS_SAVE_MODE`.
2941
3579
  safe_serialization (`bool`, *optional*, defaults to `True`):
2942
3580
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
3581
+ transformer_lora_adapter_metadata:
3582
+ LoRA adapter metadata associated with the transformer to be serialized with the state dict.
2943
3583
  """
2944
3584
  state_dict = {}
3585
+ lora_adapter_metadata = {}
2945
3586
 
2946
3587
  if not transformer_lora_layers:
2947
3588
  raise ValueError("You must pass `transformer_lora_layers`.")
2948
3589
 
2949
- if transformer_lora_layers:
2950
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
3590
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
3591
+
3592
+ if transformer_lora_adapter_metadata is not None:
3593
+ lora_adapter_metadata.update(
3594
+ _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
3595
+ )
2951
3596
 
2952
3597
  # Save the model
2953
3598
  cls.write_lora_layers(
@@ -2957,8 +3602,10 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
2957
3602
  weight_name=weight_name,
2958
3603
  save_function=save_function,
2959
3604
  safe_serialization=safe_serialization,
3605
+ lora_adapter_metadata=lora_adapter_metadata,
2960
3606
  )
2961
3607
 
3608
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
2962
3609
  def fuse_lora(
2963
3610
  self,
2964
3611
  components: List[str] = ["transformer"],
@@ -3006,6 +3653,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
3006
3653
  **kwargs,
3007
3654
  )
3008
3655
 
3656
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
3009
3657
  def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
3010
3658
  r"""
3011
3659
  Reverses the effect of
@@ -3024,9 +3672,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
3024
3672
  super().unfuse_lora(components=components, **kwargs)
3025
3673
 
3026
3674
 
3027
- class Mochi1LoraLoaderMixin(LoraBaseMixin):
3675
+ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3028
3676
  r"""
3029
- Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
3677
+ Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`].
3030
3678
  """
3031
3679
 
3032
3680
  _lora_loadable_modules = ["transformer"]
@@ -3034,7 +3682,6 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3034
3682
 
3035
3683
  @classmethod
3036
3684
  @validate_hf_hub_args
3037
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
3038
3685
  def lora_state_dict(
3039
3686
  cls,
3040
3687
  pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -3083,7 +3730,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3083
3730
  allowed by Git.
3084
3731
  subfolder (`str`, *optional*, defaults to `""`):
3085
3732
  The subfolder location of a model file within a larger model repository on the Hub or locally.
3086
-
3733
+ return_lora_metadata (`bool`, *optional*, defaults to False):
3734
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
3087
3735
  """
3088
3736
  # Load the main state dict first which has the LoRA layers for either of
3089
3737
  # transformer and text encoder or both.
@@ -3096,18 +3744,16 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3096
3744
  subfolder = kwargs.pop("subfolder", None)
3097
3745
  weight_name = kwargs.pop("weight_name", None)
3098
3746
  use_safetensors = kwargs.pop("use_safetensors", None)
3747
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
3099
3748
 
3100
3749
  allow_pickle = False
3101
3750
  if use_safetensors is None:
3102
3751
  use_safetensors = True
3103
3752
  allow_pickle = True
3104
3753
 
3105
- user_agent = {
3106
- "file_type": "attn_procs_weights",
3107
- "framework": "pytorch",
3108
- }
3754
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
3109
3755
 
3110
- state_dict = _fetch_state_dict(
3756
+ state_dict, metadata = _fetch_state_dict(
3111
3757
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
3112
3758
  weight_name=weight_name,
3113
3759
  use_safetensors=use_safetensors,
@@ -3128,11 +3774,20 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3128
3774
  logger.warning(warn_msg)
3129
3775
  state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
3130
3776
 
3131
- return state_dict
3777
+ is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict)
3778
+ if is_non_diffusers_format:
3779
+ state_dict = _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict)
3780
+
3781
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
3782
+ return out
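A minimal sketch of the format detection added above; the key below is illustrative:

```py
# Non-diffusers LTX-Video LoRA exports key their tensors under "diffusion_model.", so a single
# prefix check decides whether the state dict needs remapping to diffusers naming before loading.
state_dict = {"diffusion_model.transformer_blocks.0.attn1.to_q.lora_A.weight": 0}
is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict)
print(is_non_diffusers_format)  # True -> convert before loading
```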
3132
3783
 
3133
3784
  # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
3134
3785
  def load_lora_weights(
3135
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
3786
+ self,
3787
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
3788
+ adapter_name: Optional[str] = None,
3789
+ hotswap: bool = False,
3790
+ **kwargs,
3136
3791
  ):
3137
3792
  """
3138
3793
  Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
@@ -3150,6 +3805,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3150
3805
  low_cpu_mem_usage (`bool`, *optional*):
3151
3806
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3152
3807
  weights.
3808
+ hotswap (`bool`, *optional*):
3809
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
3153
3810
  kwargs (`dict`, *optional*):
3154
3811
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3155
3812
  """
@@ -3167,7 +3824,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3167
3824
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
3168
3825
 
3169
3826
  # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
3170
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3827
+ kwargs["return_lora_metadata"] = True
3828
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3171
3829
 
3172
3830
  is_correct_format = all("lora" in key for key in state_dict.keys())
3173
3831
  if not is_correct_format:
@@ -3177,14 +3835,23 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3177
3835
  state_dict,
3178
3836
  transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
3179
3837
  adapter_name=adapter_name,
3838
+ metadata=metadata,
3180
3839
  _pipeline=self,
3181
3840
  low_cpu_mem_usage=low_cpu_mem_usage,
3841
+ hotswap=hotswap,
3182
3842
  )
3183
3843
 
3184
3844
  @classmethod
3185
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
3845
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
3186
3846
  def load_lora_into_transformer(
3187
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
3847
+ cls,
3848
+ state_dict,
3849
+ transformer,
3850
+ adapter_name=None,
3851
+ _pipeline=None,
3852
+ low_cpu_mem_usage=False,
3853
+ hotswap: bool = False,
3854
+ metadata=None,
3188
3855
  ):
3189
3856
  """
3190
3857
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -3194,7 +3861,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3194
3861
  A standard state dict containing the lora layer parameters. The keys can either be indexed directly
3195
3862
  into the unet or prefixed with an additional `unet` which can be used to distinguish between text
3196
3863
  encoder lora layers.
3197
- transformer (`MochiTransformer3DModel`):
3864
+ transformer (`LTXVideoTransformer3DModel`):
3198
3865
  The Transformer model to load the LoRA layers into.
3199
3866
  adapter_name (`str`, *optional*):
3200
3867
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
@@ -3202,29 +3869,11 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3202
3869
  low_cpu_mem_usage (`bool`, *optional*):
3203
3870
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3204
3871
  weights.
3205
- hotswap : (`bool`, *optional*)
3206
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
3207
- in-place. This means that, instead of loading an additional adapter, this will take the existing
3208
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
3209
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
3210
- torch.compile, loading the new adapter does not require recompilation of the model. When using
3211
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
3212
-
3213
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
3214
- to call an additional method before loading the adapter:
3215
-
3216
- ```py
3217
- pipeline = ... # load diffusers pipeline
3218
- max_rank = ... # the highest rank among all LoRAs that you want to load
3219
- # call *before* compiling and loading the LoRA adapter
3220
- pipeline.enable_lora_hotswap(target_rank=max_rank)
3221
- pipeline.load_lora_weights(file_name)
3222
- # optionally compile the model now
3223
- ```
3224
-
3225
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
3226
- limitations to this technique, which are documented here:
3227
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
3872
+ hotswap (`bool`, *optional*):
3873
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
3874
+ metadata (`dict`):
3875
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
3876
+ from the state dict.
3228
3877
  """
3229
3878
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3230
3879
  raise ValueError(
@@ -3237,6 +3886,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3237
3886
  state_dict,
3238
3887
  network_alphas=None,
3239
3888
  adapter_name=adapter_name,
3889
+ metadata=metadata,
3240
3890
  _pipeline=_pipeline,
3241
3891
  low_cpu_mem_usage=low_cpu_mem_usage,
3242
3892
  hotswap=hotswap,
@@ -3252,9 +3902,10 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3252
3902
  weight_name: str = None,
3253
3903
  save_function: Callable = None,
3254
3904
  safe_serialization: bool = True,
3905
+ transformer_lora_adapter_metadata: Optional[dict] = None,
3255
3906
  ):
3256
3907
  r"""
3257
- Save the LoRA parameters corresponding to the UNet and text encoder.
3908
+ Save the LoRA parameters corresponding to the transformer.
3258
3909
 
3259
3910
  Arguments:
3260
3911
  save_directory (`str` or `os.PathLike`):
@@ -3271,14 +3922,21 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3271
3922
  `DIFFUSERS_SAVE_MODE`.
3272
3923
  safe_serialization (`bool`, *optional*, defaults to `True`):
3273
3924
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
3925
+ transformer_lora_adapter_metadata:
3926
+ LoRA adapter metadata associated with the transformer to be serialized with the state dict.
3274
3927
  """
3275
3928
  state_dict = {}
3929
+ lora_adapter_metadata = {}
3276
3930
 
3277
3931
  if not transformer_lora_layers:
3278
3932
  raise ValueError("You must pass `transformer_lora_layers`.")
3279
3933
 
3280
- if transformer_lora_layers:
3281
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
3934
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
3935
+
3936
+ if transformer_lora_adapter_metadata is not None:
3937
+ lora_adapter_metadata.update(
3938
+ _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
3939
+ )
3282
3940
 
3283
3941
  # Save the model
3284
3942
  cls.write_lora_layers(
@@ -3288,6 +3946,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3288
3946
  weight_name=weight_name,
3289
3947
  save_function=save_function,
3290
3948
  safe_serialization=safe_serialization,
3949
+ lora_adapter_metadata=lora_adapter_metadata,
3291
3950
  )
3292
3951
 
3293
3952
  # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -3357,9 +4016,9 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
3357
4016
  super().unfuse_lora(components=components, **kwargs)
3358
4017
 
3359
4018
 
3360
- class LTXVideoLoraLoaderMixin(LoraBaseMixin):
4019
+ class SanaLoraLoaderMixin(LoraBaseMixin):
3361
4020
  r"""
3362
- Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`].
4021
+ Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
3363
4022
  """
3364
4023
 
3365
4024
  _lora_loadable_modules = ["transformer"]
@@ -3367,7 +4026,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3367
4026
 
3368
4027
  @classmethod
3369
4028
  @validate_hf_hub_args
3370
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
4029
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
3371
4030
  def lora_state_dict(
3372
4031
  cls,
3373
4032
  pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -3416,6 +4075,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3416
4075
  allowed by Git.
3417
4076
  subfolder (`str`, *optional*, defaults to `""`):
3418
4077
  The subfolder location of a model file within a larger model repository on the Hub or locally.
4078
+ return_lora_metadata (`bool`, *optional*, defaults to False):
4079
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
3419
4080
 
3420
4081
  """
3421
4082
  # Load the main state dict first which has the LoRA layers for either of
@@ -3429,18 +4090,16 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3429
4090
  subfolder = kwargs.pop("subfolder", None)
3430
4091
  weight_name = kwargs.pop("weight_name", None)
3431
4092
  use_safetensors = kwargs.pop("use_safetensors", None)
4093
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
3432
4094
 
3433
4095
  allow_pickle = False
3434
4096
  if use_safetensors is None:
3435
4097
  use_safetensors = True
3436
4098
  allow_pickle = True
3437
4099
 
3438
- user_agent = {
3439
- "file_type": "attn_procs_weights",
3440
- "framework": "pytorch",
3441
- }
4100
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
3442
4101
 
3443
- state_dict = _fetch_state_dict(
4102
+ state_dict, metadata = _fetch_state_dict(
3444
4103
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
3445
4104
  weight_name=weight_name,
3446
4105
  use_safetensors=use_safetensors,
@@ -3461,11 +4120,16 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3461
4120
  logger.warning(warn_msg)
3462
4121
  state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
3463
4122
 
3464
- return state_dict
4123
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
4124
+ return out
3465
4125
 
3466
4126
  # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
3467
4127
  def load_lora_weights(
3468
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
4128
+ self,
4129
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
4130
+ adapter_name: Optional[str] = None,
4131
+ hotswap: bool = False,
4132
+ **kwargs,
3469
4133
  ):
3470
4134
  """
3471
4135
  Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
@@ -3483,6 +4147,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3483
4147
  low_cpu_mem_usage (`bool`, *optional*):
3484
4148
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3485
4149
  weights.
4150
+ hotswap (`bool`, *optional*):
4151
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
3486
4152
  kwargs (`dict`, *optional*):
3487
4153
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3488
4154
  """
@@ -3500,7 +4166,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3500
4166
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
3501
4167
 
3502
4168
  # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
3503
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4169
+ kwargs["return_lora_metadata"] = True
4170
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3504
4171
 
3505
4172
  is_correct_format = all("lora" in key for key in state_dict.keys())
3506
4173
  if not is_correct_format:
@@ -3510,14 +4177,23 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3510
4177
  state_dict,
3511
4178
  transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
3512
4179
  adapter_name=adapter_name,
4180
+ metadata=metadata,
3513
4181
  _pipeline=self,
3514
4182
  low_cpu_mem_usage=low_cpu_mem_usage,
4183
+ hotswap=hotswap,
3515
4184
  )
3516
4185
 
3517
4186
  @classmethod
3518
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
4187
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
3519
4188
  def load_lora_into_transformer(
3520
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
4189
+ cls,
4190
+ state_dict,
4191
+ transformer,
4192
+ adapter_name=None,
4193
+ _pipeline=None,
4194
+ low_cpu_mem_usage=False,
4195
+ hotswap: bool = False,
4196
+ metadata=None,
3521
4197
  ):
3522
4198
  """
3523
4199
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -3527,7 +4203,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3527
4203
  A standard state dict containing the lora layer parameters. The keys can either be indexed directly
3528
4204
  into the unet or prefixed with an additional `unet` which can be used to distinguish between text
3529
4205
  encoder lora layers.
3530
- transformer (`LTXVideoTransformer3DModel`):
4206
+ transformer (`SanaTransformer2DModel`):
3531
4207
  The Transformer model to load the LoRA layers into.
3532
4208
  adapter_name (`str`, *optional*):
3533
4209
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
@@ -3535,29 +4211,11 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3535
4211
  low_cpu_mem_usage (`bool`, *optional*):
3536
4212
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3537
4213
  weights.
3538
- hotswap : (`bool`, *optional*)
3539
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
3540
- in-place. This means that, instead of loading an additional adapter, this will take the existing
3541
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
3542
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
3543
- torch.compile, loading the new adapter does not require recompilation of the model. When using
3544
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
3545
-
3546
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
3547
- to call an additional method before loading the adapter:
3548
-
3549
- ```py
3550
- pipeline = ... # load diffusers pipeline
3551
- max_rank = ... # the highest rank among all LoRAs that you want to load
3552
- # call *before* compiling and loading the LoRA adapter
3553
- pipeline.enable_lora_hotswap(target_rank=max_rank)
3554
- pipeline.load_lora_weights(file_name)
3555
- # optionally compile the model now
3556
- ```
3557
-
3558
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
3559
- limitations to this technique, which are documented here:
3560
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
4214
+ hotswap (`bool`, *optional*):
4215
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
4216
+ metadata (`dict`):
4217
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
4218
+ from the state dict.
3561
4219
  """
3562
4220
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3563
4221
  raise ValueError(
@@ -3570,6 +4228,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3570
4228
  state_dict,
3571
4229
  network_alphas=None,
3572
4230
  adapter_name=adapter_name,
4231
+ metadata=metadata,
3573
4232
  _pipeline=_pipeline,
3574
4233
  low_cpu_mem_usage=low_cpu_mem_usage,
3575
4234
  hotswap=hotswap,
@@ -3585,9 +4244,10 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3585
4244
  weight_name: str = None,
3586
4245
  save_function: Callable = None,
3587
4246
  safe_serialization: bool = True,
4247
+ transformer_lora_adapter_metadata: Optional[dict] = None,
3588
4248
  ):
3589
4249
  r"""
3590
- Save the LoRA parameters corresponding to the UNet and text encoder.
4250
+ Save the LoRA parameters corresponding to the transformer.
3591
4251
 
3592
4252
  Arguments:
3593
4253
  save_directory (`str` or `os.PathLike`):
@@ -3604,14 +4264,21 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3604
4264
  `DIFFUSERS_SAVE_MODE`.
3605
4265
  safe_serialization (`bool`, *optional*, defaults to `True`):
3606
4266
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
4267
+ transformer_lora_adapter_metadata:
4268
+ LoRA adapter metadata associated with the transformer to be serialized with the state dict.
3607
4269
  """
3608
4270
  state_dict = {}
4271
+ lora_adapter_metadata = {}
3609
4272
 
3610
4273
  if not transformer_lora_layers:
3611
4274
  raise ValueError("You must pass `transformer_lora_layers`.")
3612
4275
 
3613
- if transformer_lora_layers:
3614
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
4276
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
4277
+
4278
+ if transformer_lora_adapter_metadata is not None:
4279
+ lora_adapter_metadata.update(
4280
+ _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
4281
+ )
3615
4282
 
3616
4283
  # Save the model
3617
4284
  cls.write_lora_layers(
@@ -3621,6 +4288,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3621
4288
  weight_name=weight_name,
3622
4289
  save_function=save_function,
3623
4290
  safe_serialization=safe_serialization,
4291
+ lora_adapter_metadata=lora_adapter_metadata,
3624
4292
  )
3625
4293
 
3626
4294
  # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
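The `save_lora_weights` hunks above add a `transformer_lora_adapter_metadata` argument that is packed with `_pack_dict_with_prefix` and forwarded to `write_lora_layers`. A minimal sketch of how a caller might use it, assuming a LoRA adapter is already attached to `pipe.transformer`; the repo id, the metadata keys (`r`, `lora_alpha`), and the use of `peft.utils.get_peft_model_state_dict` are illustrative assumptions, not taken from this diff.

```py
# Minimal sketch, assuming a LoRA adapter is already attached to pipe.transformer.
# The repo id, metadata keys, and the peft helper are illustrative assumptions.
import torch
from diffusers import DiffusionPipeline
from peft.utils import get_peft_model_state_dict

pipe = DiffusionPipeline.from_pretrained("<repo-id>", torch_dtype=torch.bfloat16)
transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)  # adapter weights only

pipe.save_lora_weights(
    save_directory="./my_lora",
    transformer_lora_layers=transformer_lora_layers,
    transformer_lora_adapter_metadata={"r": 16, "lora_alpha": 16},  # serialized alongside the weights
)
```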
@@ -3690,9 +4358,9 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
3690
4358
  super().unfuse_lora(components=components, **kwargs)
3691
4359
 
3692
4360
 
3693
- class SanaLoraLoaderMixin(LoraBaseMixin):
4361
+ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
3694
4362
  r"""
3695
- Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
4363
+ Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
3696
4364
  """
3697
4365
 
3698
4366
  _lora_loadable_modules = ["transformer"]
@@ -3700,7 +4368,6 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3700
4368
 
3701
4369
  @classmethod
3702
4370
  @validate_hf_hub_args
3703
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
3704
4371
  def lora_state_dict(
3705
4372
  cls,
3706
4373
  pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -3711,7 +4378,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3711
4378
 
3712
4379
  <Tip warning={true}>
3713
4380
 
3714
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
4381
+ We support loading original format HunyuanVideo LoRA checkpoints.
3715
4382
 
3716
4383
  This function is experimental and might change in the future.
3717
4384
 
@@ -3749,7 +4416,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3749
4416
  allowed by Git.
3750
4417
  subfolder (`str`, *optional*, defaults to `""`):
3751
4418
  The subfolder location of a model file within a larger model repository on the Hub or locally.
3752
-
4419
+ return_lora_metadata (`bool`, *optional*, defaults to False):
4420
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
3753
4421
  """
3754
4422
  # Load the main state dict first which has the LoRA layers for either of
3755
4423
  # transformer and text encoder or both.
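The `return_lora_metadata` flag documented above (and implemented in the hunks that follow) makes `lora_state_dict` return a `(state_dict, metadata)` tuple instead of only the state dict. A minimal sketch; the repo id and checkpoint path are placeholders.

```py
# Minimal sketch of the new return_lora_metadata flag; paths are placeholders.
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("<repo-id>")
state_dict, metadata = pipe.lora_state_dict(
    "<path/to/lora.safetensors>", return_lora_metadata=True
)
print(len(state_dict))  # diffusers-style LoRA keys after any format conversion
print(metadata)         # adapter metadata embedded in the file, if any (may be empty)
```

Without the flag, the call keeps its previous behaviour and returns only the state dict.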
@@ -3762,18 +4430,16 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3762
4430
  subfolder = kwargs.pop("subfolder", None)
3763
4431
  weight_name = kwargs.pop("weight_name", None)
3764
4432
  use_safetensors = kwargs.pop("use_safetensors", None)
4433
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
3765
4434
 
3766
4435
  allow_pickle = False
3767
4436
  if use_safetensors is None:
3768
4437
  use_safetensors = True
3769
4438
  allow_pickle = True
3770
4439
 
3771
- user_agent = {
3772
- "file_type": "attn_procs_weights",
3773
- "framework": "pytorch",
3774
- }
4440
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
3775
4441
 
3776
- state_dict = _fetch_state_dict(
4442
+ state_dict, metadata = _fetch_state_dict(
3777
4443
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
3778
4444
  weight_name=weight_name,
3779
4445
  use_safetensors=use_safetensors,
@@ -3794,11 +4460,20 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3794
4460
  logger.warning(warn_msg)
3795
4461
  state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
3796
4462
 
3797
- return state_dict
4463
+ is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
4464
+ if is_original_hunyuan_video:
4465
+ state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
4466
+
4467
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
4468
+ return out
3798
4469
 
3799
4470
  # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
3800
4471
  def load_lora_weights(
3801
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
4472
+ self,
4473
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
4474
+ adapter_name: Optional[str] = None,
4475
+ hotswap: bool = False,
4476
+ **kwargs,
3802
4477
  ):
3803
4478
  """
3804
4479
  Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
@@ -3816,6 +4491,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3816
4491
  low_cpu_mem_usage (`bool`, *optional*):
3817
4492
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3818
4493
  weights.
4494
+ hotswap (`bool`, *optional*):
4495
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
3819
4496
  kwargs (`dict`, *optional*):
3820
4497
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3821
4498
  """
@@ -3833,7 +4510,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3833
4510
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
3834
4511
 
3835
4512
  # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
3836
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4513
+ kwargs["return_lora_metadata"] = True
4514
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3837
4515
 
3838
4516
  is_correct_format = all("lora" in key for key in state_dict.keys())
3839
4517
  if not is_correct_format:
@@ -3843,14 +4521,23 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3843
4521
  state_dict,
3844
4522
  transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
3845
4523
  adapter_name=adapter_name,
4524
+ metadata=metadata,
3846
4525
  _pipeline=self,
3847
4526
  low_cpu_mem_usage=low_cpu_mem_usage,
4527
+ hotswap=hotswap,
3848
4528
  )
3849
4529
 
3850
4530
  @classmethod
3851
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
4531
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
3852
4532
  def load_lora_into_transformer(
3853
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
4533
+ cls,
4534
+ state_dict,
4535
+ transformer,
4536
+ adapter_name=None,
4537
+ _pipeline=None,
4538
+ low_cpu_mem_usage=False,
4539
+ hotswap: bool = False,
4540
+ metadata=None,
3854
4541
  ):
3855
4542
  """
3856
4543
  This will load the LoRA layers specified in `state_dict` into `transformer`.
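The long hotswap explanation removed in these hunks now lives in `load_lora_weights`; the flow it describes is: reserve capacity with `enable_lora_hotswap` before compiling, load the first adapter normally, then later reload under the same `adapter_name` with `hotswap=True`. A minimal sketch under those assumptions; the rank, repo id, and file names are illustrative.

```py
# Minimal sketch of the hotswap flow; rank, repo id, and file names are illustrative.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("<repo-id>", torch_dtype=torch.bfloat16)
pipe.enable_lora_hotswap(target_rank=64)            # call before compiling / first load
pipe.load_lora_weights("<lora_a.safetensors>", adapter_name="default")
pipe.transformer = torch.compile(pipe.transformer)  # optional: compile once

# Swap the adapter weights in place later on; no recompilation is triggered.
pipe.load_lora_weights("<lora_b.safetensors>", adapter_name="default", hotswap=True)
```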
@@ -3860,7 +4547,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3860
4547
  A standard state dict containing the lora layer parameters. The keys can either be indexed directly
3861
4548
  into the unet or prefixed with an additional `unet` which can be used to distinguish between text
3862
4549
  encoder lora layers.
3863
- transformer (`SanaTransformer2DModel`):
4550
+ transformer (`HunyuanVideoTransformer3DModel`):
3864
4551
  The Transformer model to load the LoRA layers into.
3865
4552
  adapter_name (`str`, *optional*):
3866
4553
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
@@ -3868,29 +4555,11 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3868
4555
  low_cpu_mem_usage (`bool`, *optional*):
3869
4556
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3870
4557
  weights.
3871
- hotswap : (`bool`, *optional*)
3872
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
3873
- in-place. This means that, instead of loading an additional adapter, this will take the existing
3874
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
3875
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
3876
- torch.compile, loading the new adapter does not require recompilation of the model. When using
3877
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
3878
-
3879
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
3880
- to call an additional method before loading the adapter:
3881
-
3882
- ```py
3883
- pipeline = ... # load diffusers pipeline
3884
- max_rank = ... # the highest rank among all LoRAs that you want to load
3885
- # call *before* compiling and loading the LoRA adapter
3886
- pipeline.enable_lora_hotswap(target_rank=max_rank)
3887
- pipeline.load_lora_weights(file_name)
3888
- # optionally compile the model now
3889
- ```
3890
-
3891
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
3892
- limitations to this technique, which are documented here:
3893
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
4558
+ hotswap (`bool`, *optional*):
4559
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
4560
+ metadata (`dict`):
4561
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
4562
+ from the state dict.
3894
4563
  """
3895
4564
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
3896
4565
  raise ValueError(
@@ -3903,6 +4572,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3903
4572
  state_dict,
3904
4573
  network_alphas=None,
3905
4574
  adapter_name=adapter_name,
4575
+ metadata=metadata,
3906
4576
  _pipeline=_pipeline,
3907
4577
  low_cpu_mem_usage=low_cpu_mem_usage,
3908
4578
  hotswap=hotswap,
@@ -3918,9 +4588,10 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3918
4588
  weight_name: str = None,
3919
4589
  save_function: Callable = None,
3920
4590
  safe_serialization: bool = True,
4591
+ transformer_lora_adapter_metadata: Optional[dict] = None,
3921
4592
  ):
3922
4593
  r"""
3923
- Save the LoRA parameters corresponding to the UNet and text encoder.
4594
+ Save the LoRA parameters corresponding to the transformer.
3924
4595
 
3925
4596
  Arguments:
3926
4597
  save_directory (`str` or `os.PathLike`):
@@ -3937,14 +4608,21 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3937
4608
  `DIFFUSERS_SAVE_MODE`.
3938
4609
  safe_serialization (`bool`, *optional*, defaults to `True`):
3939
4610
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
4611
+ transformer_lora_adapter_metadata:
4612
+ LoRA adapter metadata associated with the transformer to be serialized with the state dict.
3940
4613
  """
3941
4614
  state_dict = {}
4615
+ lora_adapter_metadata = {}
3942
4616
 
3943
4617
  if not transformer_lora_layers:
3944
4618
  raise ValueError("You must pass `transformer_lora_layers`.")
3945
4619
 
3946
- if transformer_lora_layers:
3947
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
4620
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
4621
+
4622
+ if transformer_lora_adapter_metadata is not None:
4623
+ lora_adapter_metadata.update(
4624
+ _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
4625
+ )
3948
4626
 
3949
4627
  # Save the model
3950
4628
  cls.write_lora_layers(
@@ -3954,6 +4632,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
3954
4632
  weight_name=weight_name,
3955
4633
  save_function=save_function,
3956
4634
  safe_serialization=safe_serialization,
4635
+ lora_adapter_metadata=lora_adapter_metadata,
3957
4636
  )
3958
4637
 
3959
4638
  # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
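When `metadata` reaches `load_lora_into_transformer`, the `LoraConfig` arguments are taken from it rather than inferred from the state dict, as the docstring above notes. A minimal sketch of calling the classmethod directly, reusing `pipe` from the earlier sketches; in normal use `load_lora_weights` does this wiring for you.

```py
# Minimal sketch: pass an explicit state dict plus metadata to the transformer loader.
# `pipe` is assumed to be a pipeline built on one of these LoRA loader mixins.
state_dict, metadata = pipe.lora_state_dict("<path/to/lora.safetensors>", return_lora_metadata=True)
pipe.load_lora_into_transformer(
    state_dict,
    transformer=pipe.transformer,
    adapter_name="my_adapter",
    metadata=metadata,   # when present, peft's LoraConfig is built from this dict
    _pipeline=pipe,
)
```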
@@ -4023,9 +4702,9 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
4023
4702
  super().unfuse_lora(components=components, **kwargs)
4024
4703
 
4025
4704
 
4026
- class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4705
+ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4027
4706
  r"""
4028
- Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
4707
+ Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`].
4029
4708
  """
4030
4709
 
4031
4710
  _lora_loadable_modules = ["transformer"]
@@ -4043,7 +4722,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4043
4722
 
4044
4723
  <Tip warning={true}>
4045
4724
 
4046
- We support loading original format HunyuanVideo LoRA checkpoints.
4725
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
4047
4726
 
4048
4727
  This function is experimental and might change in the future.
4049
4728
 
@@ -4081,7 +4760,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4081
4760
  allowed by Git.
4082
4761
  subfolder (`str`, *optional*, defaults to `""`):
4083
4762
  The subfolder location of a model file within a larger model repository on the Hub or locally.
4084
-
4763
+ return_lora_metadata (`bool`, *optional*, defaults to False):
4764
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
4085
4765
  """
4086
4766
  # Load the main state dict first which has the LoRA layers for either of
4087
4767
  # transformer and text encoder or both.
@@ -4094,18 +4774,16 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4094
4774
  subfolder = kwargs.pop("subfolder", None)
4095
4775
  weight_name = kwargs.pop("weight_name", None)
4096
4776
  use_safetensors = kwargs.pop("use_safetensors", None)
4777
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
4097
4778
 
4098
4779
  allow_pickle = False
4099
4780
  if use_safetensors is None:
4100
4781
  use_safetensors = True
4101
4782
  allow_pickle = True
4102
4783
 
4103
- user_agent = {
4104
- "file_type": "attn_procs_weights",
4105
- "framework": "pytorch",
4106
- }
4784
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
4107
4785
 
4108
- state_dict = _fetch_state_dict(
4786
+ state_dict, metadata = _fetch_state_dict(
4109
4787
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
4110
4788
  weight_name=weight_name,
4111
4789
  use_safetensors=use_safetensors,
@@ -4126,15 +4804,21 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4126
4804
  logger.warning(warn_msg)
4127
4805
  state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
4128
4806
 
4129
- is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
4130
- if is_original_hunyuan_video:
4131
- state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
4807
+ # conversion.
4808
+ non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
4809
+ if non_diffusers:
4810
+ state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
4132
4811
 
4133
- return state_dict
4812
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
4813
+ return out
4134
4814
 
4135
4815
  # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
4136
4816
  def load_lora_weights(
4137
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
4817
+ self,
4818
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
4819
+ adapter_name: Optional[str] = None,
4820
+ hotswap: bool = False,
4821
+ **kwargs,
4138
4822
  ):
4139
4823
  """
4140
4824
  Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
@@ -4152,6 +4836,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4152
4836
  low_cpu_mem_usage (`bool`, *optional*):
4153
4837
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4154
4838
  weights.
4839
+ hotswap (`bool`, *optional*):
4840
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
4155
4841
  kwargs (`dict`, *optional*):
4156
4842
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4157
4843
  """
@@ -4169,7 +4855,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4169
4855
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
4170
4856
 
4171
4857
  # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4172
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4858
+ kwargs["return_lora_metadata"] = True
4859
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4173
4860
 
4174
4861
  is_correct_format = all("lora" in key for key in state_dict.keys())
4175
4862
  if not is_correct_format:
@@ -4179,54 +4866,45 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4179
4866
  state_dict,
4180
4867
  transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
4181
4868
  adapter_name=adapter_name,
4869
+ metadata=metadata,
4182
4870
  _pipeline=self,
4183
4871
  low_cpu_mem_usage=low_cpu_mem_usage,
4872
+ hotswap=hotswap,
4184
4873
  )
4185
4874
 
4186
4875
  @classmethod
4187
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
4876
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
4188
4877
  def load_lora_into_transformer(
4189
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
4878
+ cls,
4879
+ state_dict,
4880
+ transformer,
4881
+ adapter_name=None,
4882
+ _pipeline=None,
4883
+ low_cpu_mem_usage=False,
4884
+ hotswap: bool = False,
4885
+ metadata=None,
4190
4886
  ):
4191
4887
  """
4192
- This will load the LoRA layers specified in `state_dict` into `transformer`.
4193
-
4194
- Parameters:
4195
- state_dict (`dict`):
4196
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
4197
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
4198
- encoder lora layers.
4199
- transformer (`HunyuanVideoTransformer3DModel`):
4200
- The Transformer model to load the LoRA layers into.
4201
- adapter_name (`str`, *optional*):
4202
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4203
- `default_{i}` where i is the total number of adapters being loaded.
4204
- low_cpu_mem_usage (`bool`, *optional*):
4205
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4206
- weights.
4207
- hotswap : (`bool`, *optional*)
4208
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
4209
- in-place. This means that, instead of loading an additional adapter, this will take the existing
4210
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
4211
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
4212
- torch.compile, loading the new adapter does not require recompilation of the model. When using
4213
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
4214
-
4215
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
4216
- to call an additional method before loading the adapter:
4217
-
4218
- ```py
4219
- pipeline = ... # load diffusers pipeline
4220
- max_rank = ... # the highest rank among all LoRAs that you want to load
4221
- # call *before* compiling and loading the LoRA adapter
4222
- pipeline.enable_lora_hotswap(target_rank=max_rank)
4223
- pipeline.load_lora_weights(file_name)
4224
- # optionally compile the model now
4225
- ```
4226
-
4227
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
4228
- limitations to this technique, which are documented here:
4229
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
4888
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
4889
+
4890
+ Parameters:
4891
+ state_dict (`dict`):
4892
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
4893
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
4894
+ encoder lora layers.
4895
+ transformer (`Lumina2Transformer2DModel`):
4896
+ The Transformer model to load the LoRA layers into.
4897
+ adapter_name (`str`, *optional*):
4898
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4899
+ `default_{i}` where i is the total number of adapters being loaded.
4900
+ low_cpu_mem_usage (`bool`, *optional*):
4901
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4902
+ weights.
4903
+ hotswap (`bool`, *optional*):
4904
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
4905
+ metadata (`dict`):
4906
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
4907
+ from the state dict.
4230
4908
  """
4231
4909
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4232
4910
  raise ValueError(
@@ -4239,6 +4917,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4239
4917
  state_dict,
4240
4918
  network_alphas=None,
4241
4919
  adapter_name=adapter_name,
4920
+ metadata=metadata,
4242
4921
  _pipeline=_pipeline,
4243
4922
  low_cpu_mem_usage=low_cpu_mem_usage,
4244
4923
  hotswap=hotswap,
@@ -4254,9 +4933,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4254
4933
  weight_name: str = None,
4255
4934
  save_function: Callable = None,
4256
4935
  safe_serialization: bool = True,
4936
+ transformer_lora_adapter_metadata: Optional[dict] = None,
4257
4937
  ):
4258
4938
  r"""
4259
- Save the LoRA parameters corresponding to the UNet and text encoder.
4939
+ Save the LoRA parameters corresponding to the transformer.
4260
4940
 
4261
4941
  Arguments:
4262
4942
  save_directory (`str` or `os.PathLike`):
@@ -4273,14 +4953,21 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4273
4953
  `DIFFUSERS_SAVE_MODE`.
4274
4954
  safe_serialization (`bool`, *optional*, defaults to `True`):
4275
4955
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
4956
+ transformer_lora_adapter_metadata:
4957
+ LoRA adapter metadata associated with the transformer to be serialized with the state dict.
4276
4958
  """
4277
4959
  state_dict = {}
4960
+ lora_adapter_metadata = {}
4278
4961
 
4279
4962
  if not transformer_lora_layers:
4280
4963
  raise ValueError("You must pass `transformer_lora_layers`.")
4281
4964
 
4282
- if transformer_lora_layers:
4283
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
4965
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
4966
+
4967
+ if transformer_lora_adapter_metadata is not None:
4968
+ lora_adapter_metadata.update(
4969
+ _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
4970
+ )
4284
4971
 
4285
4972
  # Save the model
4286
4973
  cls.write_lora_layers(
@@ -4290,9 +4977,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4290
4977
  weight_name=weight_name,
4291
4978
  save_function=save_function,
4292
4979
  safe_serialization=safe_serialization,
4980
+ lora_adapter_metadata=lora_adapter_metadata,
4293
4981
  )
4294
4982
 
4295
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
4983
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
4296
4984
  def fuse_lora(
4297
4985
  self,
4298
4986
  components: List[str] = ["transformer"],
@@ -4340,7 +5028,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4340
5028
  **kwargs,
4341
5029
  )
4342
5030
 
4343
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
5031
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
4344
5032
  def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
4345
5033
  r"""
4346
5034
  Reverses the effect of
@@ -4359,9 +5047,9 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
4359
5047
  super().unfuse_lora(components=components, **kwargs)
4360
5048
 
4361
5049
 
4362
- class Lumina2LoraLoaderMixin(LoraBaseMixin):
5050
+ class WanLoraLoaderMixin(LoraBaseMixin):
4363
5051
  r"""
4364
- Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`].
5052
+ Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
4365
5053
  """
4366
5054
 
4367
5055
  _lora_loadable_modules = ["transformer"]
@@ -4417,7 +5105,8 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4417
5105
  allowed by Git.
4418
5106
  subfolder (`str`, *optional*, defaults to `""`):
4419
5107
  The subfolder location of a model file within a larger model repository on the Hub or locally.
4420
-
5108
+ return_lora_metadata (`bool`, *optional*, defaults to False):
5109
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
4421
5110
  """
4422
5111
  # Load the main state dict first which has the LoRA layers for either of
4423
5112
  # transformer and text encoder or both.
@@ -4430,18 +5119,16 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4430
5119
  subfolder = kwargs.pop("subfolder", None)
4431
5120
  weight_name = kwargs.pop("weight_name", None)
4432
5121
  use_safetensors = kwargs.pop("use_safetensors", None)
5122
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
4433
5123
 
4434
5124
  allow_pickle = False
4435
5125
  if use_safetensors is None:
4436
5126
  use_safetensors = True
4437
5127
  allow_pickle = True
4438
5128
 
4439
- user_agent = {
4440
- "file_type": "attn_procs_weights",
4441
- "framework": "pytorch",
4442
- }
5129
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
4443
5130
 
4444
- state_dict = _fetch_state_dict(
5131
+ state_dict, metadata = _fetch_state_dict(
4445
5132
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
4446
5133
  weight_name=weight_name,
4447
5134
  use_safetensors=use_safetensors,
@@ -4455,6 +5142,10 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4455
5142
  user_agent=user_agent,
4456
5143
  allow_pickle=allow_pickle,
4457
5144
  )
5145
+ if any(k.startswith("diffusion_model.") for k in state_dict):
5146
+ state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
5147
+ elif any(k.startswith("lora_unet_") for k in state_dict):
5148
+ state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
4458
5149
 
4459
5150
  is_dora_scale_present = any("dora_scale" in k for k in state_dict)
4460
5151
  if is_dora_scale_present:
@@ -4462,16 +5153,63 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4462
5153
  logger.warning(warn_msg)
4463
5154
  state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
4464
5155
 
4465
- # conversion.
4466
- non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
4467
- if non_diffusers:
4468
- state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
5156
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
5157
+ return out
5158
+
5159
+ @classmethod
5160
+ def _maybe_expand_t2v_lora_for_i2v(
5161
+ cls,
5162
+ transformer: torch.nn.Module,
5163
+ state_dict,
5164
+ ):
5165
+ if transformer.config.image_dim is None:
5166
+ return state_dict
5167
+
5168
+ target_device = transformer.device
5169
+
5170
+ if any(k.startswith("transformer.blocks.") for k in state_dict):
5171
+ num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
5172
+ is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
5173
+ has_bias = any(".lora_B.bias" in k for k in state_dict)
5174
+
5175
+ if is_i2v_lora:
5176
+ return state_dict
5177
+
5178
+ for i in range(num_blocks):
5179
+ for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
5180
+ # These keys should exist if the block `i` was part of the T2V LoRA.
5181
+ ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
5182
+ ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
5183
+
5184
+ if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
5185
+ continue
5186
+
5187
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
5188
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
5189
+ )
5190
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
5191
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
5192
+ )
5193
+
5194
+ # If the original LoRA had biases (indicated by has_bias)
5195
+ # AND the specific reference bias key exists for this block.
5196
+
5197
+ ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
5198
+ if has_bias and ref_key_lora_B_bias in state_dict:
5199
+ ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
5200
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
5201
+ ref_lora_B_bias_tensor,
5202
+ device=target_device,
5203
+ )
4469
5204
 
4470
5205
  return state_dict
4471
5206
 
4472
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
4473
5207
  def load_lora_weights(
4474
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
5208
+ self,
5209
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
5210
+ adapter_name: Optional[str] = None,
5211
+ hotswap: bool = False,
5212
+ **kwargs,
4475
5213
  ):
4476
5214
  """
4477
5215
  Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
@@ -4489,6 +5227,8 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4489
5227
  low_cpu_mem_usage (`bool`, *optional*):
4490
5228
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4491
5229
  weights.
5230
+ hotswap (`bool`, *optional*):
5231
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
4492
5232
  kwargs (`dict`, *optional*):
4493
5233
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4494
5234
  """
@@ -4506,8 +5246,13 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4506
5246
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
4507
5247
 
4508
5248
  # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4509
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4510
-
5249
+ kwargs["return_lora_metadata"] = True
5250
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
5251
+ # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
5252
+ state_dict = self._maybe_expand_t2v_lora_for_i2v(
5253
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
5254
+ state_dict=state_dict,
5255
+ )
4511
5256
  is_correct_format = all("lora" in key for key in state_dict.keys())
4512
5257
  if not is_correct_format:
4513
5258
  raise ValueError("Invalid LoRA checkpoint.")
@@ -4516,14 +5261,23 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4516
5261
  state_dict,
4517
5262
  transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
4518
5263
  adapter_name=adapter_name,
5264
+ metadata=metadata,
4519
5265
  _pipeline=self,
4520
5266
  low_cpu_mem_usage=low_cpu_mem_usage,
5267
+ hotswap=hotswap,
4521
5268
  )
4522
5269
 
4523
5270
  @classmethod
4524
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
5271
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
4525
5272
  def load_lora_into_transformer(
4526
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
5273
+ cls,
5274
+ state_dict,
5275
+ transformer,
5276
+ adapter_name=None,
5277
+ _pipeline=None,
5278
+ low_cpu_mem_usage=False,
5279
+ hotswap: bool = False,
5280
+ metadata=None,
4527
5281
  ):
4528
5282
  """
4529
5283
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -4533,7 +5287,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4533
5287
  A standard state dict containing the lora layer parameters. The keys can either be indexed directly
4534
5288
  into the unet or prefixed with an additional `unet` which can be used to distinguish between text
4535
5289
  encoder lora layers.
4536
- transformer (`Lumina2Transformer2DModel`):
5290
+ transformer (`WanTransformer3DModel`):
4537
5291
  The Transformer model to load the LoRA layers into.
4538
5292
  adapter_name (`str`, *optional*):
4539
5293
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
@@ -4541,29 +5295,11 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4541
5295
  low_cpu_mem_usage (`bool`, *optional*):
4542
5296
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4543
5297
  weights.
4544
- hotswap : (`bool`, *optional*)
4545
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
4546
- in-place. This means that, instead of loading an additional adapter, this will take the existing
4547
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
4548
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
4549
- torch.compile, loading the new adapter does not require recompilation of the model. When using
4550
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
4551
-
4552
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
4553
- to call an additional method before loading the adapter:
4554
-
4555
- ```py
4556
- pipeline = ... # load diffusers pipeline
4557
- max_rank = ... # the highest rank among all LoRAs that you want to load
4558
- # call *before* compiling and loading the LoRA adapter
4559
- pipeline.enable_lora_hotswap(target_rank=max_rank)
4560
- pipeline.load_lora_weights(file_name)
4561
- # optionally compile the model now
4562
- ```
4563
-
4564
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
4565
- limitations to this technique, which are documented here:
4566
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
5298
+ hotswap (`bool`, *optional*):
5299
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
5300
+ metadata (`dict`):
5301
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
5302
+ from the state dict.
4567
5303
  """
4568
5304
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4569
5305
  raise ValueError(
@@ -4576,6 +5312,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4576
5312
  state_dict,
4577
5313
  network_alphas=None,
4578
5314
  adapter_name=adapter_name,
5315
+ metadata=metadata,
4579
5316
  _pipeline=_pipeline,
4580
5317
  low_cpu_mem_usage=low_cpu_mem_usage,
4581
5318
  hotswap=hotswap,
@@ -4591,9 +5328,10 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4591
5328
  weight_name: str = None,
4592
5329
  save_function: Callable = None,
4593
5330
  safe_serialization: bool = True,
5331
+ transformer_lora_adapter_metadata: Optional[dict] = None,
4594
5332
  ):
4595
5333
  r"""
4596
- Save the LoRA parameters corresponding to the UNet and text encoder.
5334
+ Save the LoRA parameters corresponding to the transformer.
4597
5335
 
4598
5336
  Arguments:
4599
5337
  save_directory (`str` or `os.PathLike`):
@@ -4610,14 +5348,21 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4610
5348
  `DIFFUSERS_SAVE_MODE`.
4611
5349
  safe_serialization (`bool`, *optional*, defaults to `True`):
4612
5350
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
5351
+ transformer_lora_adapter_metadata:
5352
+ LoRA adapter metadata associated with the transformer to be serialized with the state dict.
4613
5353
  """
4614
5354
  state_dict = {}
5355
+ lora_adapter_metadata = {}
4615
5356
 
4616
5357
  if not transformer_lora_layers:
4617
5358
  raise ValueError("You must pass `transformer_lora_layers`.")
4618
5359
 
4619
- if transformer_lora_layers:
4620
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
5360
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
5361
+
5362
+ if transformer_lora_adapter_metadata is not None:
5363
+ lora_adapter_metadata.update(
5364
+ _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
5365
+ )
4621
5366
 
4622
5367
  # Save the model
4623
5368
  cls.write_lora_layers(
@@ -4627,9 +5372,10 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4627
5372
  weight_name=weight_name,
4628
5373
  save_function=save_function,
4629
5374
  safe_serialization=safe_serialization,
5375
+ lora_adapter_metadata=lora_adapter_metadata,
4630
5376
  )
4631
5377
 
4632
- # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
5378
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
4633
5379
  def fuse_lora(
4634
5380
  self,
4635
5381
  components: List[str] = ["transformer"],
@@ -4677,7 +5423,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4677
5423
  **kwargs,
4678
5424
  )
4679
5425
 
4680
- # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
5426
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
4681
5427
  def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
4682
5428
  r"""
4683
5429
  Reverses the effect of
@@ -4696,9 +5442,9 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
4696
5442
  super().unfuse_lora(components=components, **kwargs)
4697
5443
 
4698
5444
 
4699
- class WanLoraLoaderMixin(LoraBaseMixin):
5445
+ class CogView4LoraLoaderMixin(LoraBaseMixin):
4700
5446
  r"""
4701
- Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
5447
+ Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
4702
5448
  """
4703
5449
 
4704
5450
  _lora_loadable_modules = ["transformer"]
@@ -4706,6 +5452,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
4706
5452
 
4707
5453
  @classmethod
4708
5454
  @validate_hf_hub_args
5455
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
4709
5456
  def lora_state_dict(
4710
5457
  cls,
4711
5458
  pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -4754,6 +5501,8 @@ class WanLoraLoaderMixin(LoraBaseMixin):
4754
5501
  allowed by Git.
4755
5502
  subfolder (`str`, *optional*, defaults to `""`):
4756
5503
  The subfolder location of a model file within a larger model repository on the Hub or locally.
5504
+ return_lora_metadata (`bool`, *optional*, defaults to False):
5505
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
4757
5506
 
4758
5507
  """
4759
5508
  # Load the main state dict first which has the LoRA layers for either of
@@ -4767,18 +5516,16 @@ class WanLoraLoaderMixin(LoraBaseMixin):
4767
5516
  subfolder = kwargs.pop("subfolder", None)
4768
5517
  weight_name = kwargs.pop("weight_name", None)
4769
5518
  use_safetensors = kwargs.pop("use_safetensors", None)
5519
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
4770
5520
 
4771
5521
  allow_pickle = False
4772
5522
  if use_safetensors is None:
4773
5523
  use_safetensors = True
4774
5524
  allow_pickle = True
4775
5525
 
4776
- user_agent = {
4777
- "file_type": "attn_procs_weights",
4778
- "framework": "pytorch",
4779
- }
5526
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
4780
5527
 
4781
- state_dict = _fetch_state_dict(
5528
+ state_dict, metadata = _fetch_state_dict(
4782
5529
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
4783
5530
  weight_name=weight_name,
4784
5531
  use_safetensors=use_safetensors,
@@ -4792,8 +5539,6 @@ class WanLoraLoaderMixin(LoraBaseMixin):
4792
5539
  user_agent=user_agent,
4793
5540
  allow_pickle=allow_pickle,
4794
5541
  )
4795
- if any(k.startswith("diffusion_model.") for k in state_dict):
4796
- state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
4797
5542
 
4798
5543
  is_dora_scale_present = any("dora_scale" in k for k in state_dict)
4799
5544
  if is_dora_scale_present:
@@ -4801,37 +5546,16 @@ class WanLoraLoaderMixin(LoraBaseMixin):
4801
5546
  logger.warning(warn_msg)
4802
5547
  state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
4803
5548
 
4804
- return state_dict
4805
-
4806
- @classmethod
4807
- def _maybe_expand_t2v_lora_for_i2v(
4808
- cls,
4809
- transformer: torch.nn.Module,
4810
- state_dict,
4811
- ):
4812
- if transformer.config.image_dim is None:
4813
- return state_dict
4814
-
4815
- if any(k.startswith("transformer.blocks.") for k in state_dict):
4816
- num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
4817
- is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
4818
-
4819
- if is_i2v_lora:
4820
- return state_dict
4821
-
4822
- for i in range(num_blocks):
4823
- for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
4824
- state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
4825
- state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
4826
- )
4827
- state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
4828
- state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
4829
- )
4830
-
4831
- return state_dict
5549
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
5550
+ return out
4832
5551
 
5552
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
4833
5553
  def load_lora_weights(
4834
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
5554
+ self,
5555
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
5556
+ adapter_name: Optional[str] = None,
5557
+ hotswap: bool = False,
5558
+ **kwargs,
4835
5559
  ):
4836
5560
  """
4837
5561
  Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
@@ -4849,6 +5573,8 @@ class WanLoraLoaderMixin(LoraBaseMixin):
4849
5573
  low_cpu_mem_usage (`bool`, *optional*):
4850
5574
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4851
5575
  weights.
5576
+ hotswap (`bool`, *optional*):
5577
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
4852
5578
  kwargs (`dict`, *optional*):
4853
5579
  See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4854
5580
  """
@@ -4866,12 +5592,9 @@ class WanLoraLoaderMixin(LoraBaseMixin):
4866
5592
  pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
4867
5593
 
4868
5594
  # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4869
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4870
- # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
4871
- state_dict = self._maybe_expand_t2v_lora_for_i2v(
4872
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
4873
- state_dict=state_dict,
4874
- )
5595
+ kwargs["return_lora_metadata"] = True
5596
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
5597
+
4875
5598
  is_correct_format = all("lora" in key for key in state_dict.keys())
4876
5599
  if not is_correct_format:
4877
5600
  raise ValueError("Invalid LoRA checkpoint.")
@@ -4880,14 +5603,23 @@ class WanLoraLoaderMixin(LoraBaseMixin):
4880
5603
  state_dict,
4881
5604
  transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
4882
5605
  adapter_name=adapter_name,
5606
+ metadata=metadata,
4883
5607
  _pipeline=self,
4884
5608
  low_cpu_mem_usage=low_cpu_mem_usage,
5609
+ hotswap=hotswap,
4885
5610
  )
4886
5611
 
4887
5612
  @classmethod
4888
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
5613
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
4889
5614
  def load_lora_into_transformer(
4890
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
5615
+ cls,
5616
+ state_dict,
5617
+ transformer,
5618
+ adapter_name=None,
5619
+ _pipeline=None,
5620
+ low_cpu_mem_usage=False,
5621
+ hotswap: bool = False,
5622
+ metadata=None,
4891
5623
  ):
4892
5624
  """
4893
5625
  This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -4897,7 +5629,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
4897
5629
  A standard state dict containing the lora layer parameters. The keys can either be indexed directly
4898
5630
  into the unet or prefixed with an additional `unet` which can be used to distinguish between text
4899
5631
  encoder lora layers.
4900
- transformer (`WanTransformer3DModel`):
5632
+ transformer (`CogView4Transformer2DModel`):
4901
5633
  The Transformer model to load the LoRA layers into.
4902
5634
  adapter_name (`str`, *optional*):
4903
5635
  Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
@@ -4905,29 +5637,11 @@ class WanLoraLoaderMixin(LoraBaseMixin):
4905
5637
  low_cpu_mem_usage (`bool`, *optional*):
4906
5638
  Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4907
5639
  weights.
4908
- hotswap : (`bool`, *optional*)
4909
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
4910
- in-place. This means that, instead of loading an additional adapter, this will take the existing
4911
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
4912
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
4913
- torch.compile, loading the new adapter does not require recompilation of the model. When using
4914
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
4915
-
4916
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
4917
- to call an additional method before loading the adapter:
4918
-
4919
- ```py
4920
- pipeline = ... # load diffusers pipeline
4921
- max_rank = ... # the highest rank among all LoRAs that you want to load
4922
- # call *before* compiling and loading the LoRA adapter
4923
- pipeline.enable_lora_hotswap(target_rank=max_rank)
4924
- pipeline.load_lora_weights(file_name)
4925
- # optionally compile the model now
4926
- ```
4927
-
4928
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
4929
- limitations to this technique, which are documented here:
4930
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
5640
+ hotswap (`bool`, *optional*):
5641
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
5642
+ metadata (`dict`):
5643
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
5644
+ from the state dict.
4931
5645
  """
4932
5646
  if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4933
5647
  raise ValueError(
@@ -4940,6 +5654,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
4940
5654
  state_dict,
4941
5655
  network_alphas=None,
4942
5656
  adapter_name=adapter_name,
5657
+ metadata=metadata,
4943
5658
  _pipeline=_pipeline,
4944
5659
  low_cpu_mem_usage=low_cpu_mem_usage,
4945
5660
  hotswap=hotswap,
@@ -4955,9 +5670,10 @@ class WanLoraLoaderMixin(LoraBaseMixin):
4955
5670
  weight_name: str = None,
4956
5671
  save_function: Callable = None,
4957
5672
  safe_serialization: bool = True,
5673
+ transformer_lora_adapter_metadata: Optional[dict] = None,
4958
5674
  ):
4959
5675
  r"""
4960
- Save the LoRA parameters corresponding to the UNet and text encoder.
5676
+ Save the LoRA parameters corresponding to the transformer.
4961
5677
 
4962
5678
  Arguments:
4963
5679
  save_directory (`str` or `os.PathLike`):
@@ -4974,14 +5690,21 @@ class WanLoraLoaderMixin(LoraBaseMixin):
4974
5690
  `DIFFUSERS_SAVE_MODE`.
4975
5691
  safe_serialization (`bool`, *optional*, defaults to `True`):
4976
5692
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
5693
+ transformer_lora_adapter_metadata:
5694
+ LoRA adapter metadata associated with the transformer to be serialized with the state dict.
4977
5695
  """
4978
5696
  state_dict = {}
5697
+ lora_adapter_metadata = {}
4979
5698
 
4980
5699
  if not transformer_lora_layers:
4981
5700
  raise ValueError("You must pass `transformer_lora_layers`.")
4982
5701
 
4983
- if transformer_lora_layers:
4984
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
5702
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
5703
+
5704
+ if transformer_lora_adapter_metadata is not None:
5705
+ lora_adapter_metadata.update(
5706
+ _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
5707
+ )
4985
5708
 
4986
5709
  # Save the model
4987
5710
  cls.write_lora_layers(
@@ -4991,6 +5714,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
4991
5714
  weight_name=weight_name,
4992
5715
  save_function=save_function,
4993
5716
  safe_serialization=safe_serialization,
5717
+ lora_adapter_metadata=lora_adapter_metadata,
4994
5718
  )
4995
5719
 
4996
5720
  # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
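Throughout these classes the adapter metadata is namespaced with `_pack_dict_with_prefix(metadata, cls.transformer_name)` before serialization. The exact output of that private helper is not shown in this diff; the hypothetical function below only illustrates the namespacing idea.

```py
# Illustration only: a hypothetical prefixing helper. The real _pack_dict_with_prefix
# is private to diffusers and may differ in detail.
def pack_with_prefix(metadata: dict, prefix: str) -> dict:
    return {f"{prefix}.{k}": v for k, v in metadata.items()}

print(pack_with_prefix({"r": 16, "lora_alpha": 16}, "transformer"))
# e.g. {'transformer.r': 16, 'transformer.lora_alpha': 16}
```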
@@ -5060,9 +5784,9 @@ class WanLoraLoaderMixin(LoraBaseMixin):
5060
5784
  super().unfuse_lora(components=components, **kwargs)
5061
5785
 
5062
5786
 
5063
- class CogView4LoraLoaderMixin(LoraBaseMixin):
5787
+ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
5064
5788
  r"""
5065
- Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
5789
+ Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`].
5066
5790
  """
5067
5791
 
5068
5792
  _lora_loadable_modules = ["transformer"]
@@ -5070,7 +5794,6 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
5070
5794
 
5071
5795
  @classmethod
5072
5796
  @validate_hf_hub_args
5073
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
5074
5797
  def lora_state_dict(
5075
5798
  cls,
5076
5799
  pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -5119,7 +5842,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
5119
5842
  allowed by Git.
5120
5843
  subfolder (`str`, *optional*, defaults to `""`):
5121
5844
  The subfolder location of a model file within a larger model repository on the Hub or locally.
5122
-
5845
+ return_lora_metadata (`bool`, *optional*, defaults to False):
5846
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
5123
5847
  """
5124
5848
  # Load the main state dict first which has the LoRA layers for either of
5125
5849
  # transformer and text encoder or both.
@@ -5132,18 +5856,16 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
5132
5856
  subfolder = kwargs.pop("subfolder", None)
5133
5857
  weight_name = kwargs.pop("weight_name", None)
5134
5858
  use_safetensors = kwargs.pop("use_safetensors", None)
5859
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
5135
5860
 
5136
5861
  allow_pickle = False
5137
5862
  if use_safetensors is None:
5138
5863
  use_safetensors = True
5139
5864
  allow_pickle = True
5140
5865
 
5141
- user_agent = {
5142
- "file_type": "attn_procs_weights",
5143
- "framework": "pytorch",
5144
- }
5866
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
5145
5867
 
5146
- state_dict = _fetch_state_dict(
5868
+ state_dict, metadata = _fetch_state_dict(
5147
5869
  pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
5148
5870
  weight_name=weight_name,
5149
5871
  use_safetensors=use_safetensors,
@@ -5164,11 +5886,20 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

-        return state_dict
+        is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
+        if is_non_diffusers_format:
+            state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
+
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
     def load_lora_weights(
-        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name: Optional[str] = None,
+        hotswap: bool = False,
+        **kwargs,
     ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
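The hunk above also routes checkpoints whose keys carry the `diffusion_model` marker (the usual non-diffusers layout) through `_convert_non_diffusers_hidream_lora_to_diffusers` before returning. A rough illustration of just the detection step; the key names are invented for the example, not copied from a real HiDream LoRA:

```py
# Illustrative sketch of the format check used in the diff above.
non_diffusers_keys = [
    "diffusion_model.double_stream_blocks.0.attn.to_q.lora_A.weight",  # made-up key
    "diffusion_model.double_stream_blocks.0.attn.to_q.lora_B.weight",  # made-up key
]
diffusers_keys = [
    "transformer.double_stream_blocks.0.attn.to_q.lora_A.weight",      # made-up key
]

def needs_conversion(keys) -> bool:
    # Same condition as in the diff: any "diffusion_model" substring triggers conversion.
    return any("diffusion_model" in k for k in keys)

assert needs_conversion(non_diffusers_keys)
assert not needs_conversion(diffusers_keys)
```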
@@ -5186,6 +5917,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
             kwargs (`dict`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
@@ -5203,7 +5936,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -5213,14 +5947,23 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )

     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
     def load_lora_into_transformer(
-        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -5230,7 +5973,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
-            transformer (`CogView4Transformer2DModel`):
+            transformer (`HiDreamImageTransformer2DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
@@ -5238,29 +5981,11 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
-            hotswap : (`bool`, *optional*)
-                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
-                in-place. This means that, instead of loading an additional adapter, this will take the existing
-                adapter weights and replace them with the weights of the new adapter. This can be faster and more
-                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
-                torch.compile, loading the new adapter does not require recompilation of the model. When using
-                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
-                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
-                to call an additional method before loading the adapter:
-
-                ```py
-                pipeline = ...  # load diffusers pipeline
-                max_rank = ...  # the highest rank among all LoRAs that you want to load
-                # call *before* compiling and loading the LoRA adapter
-                pipeline.enable_lora_hotswap(target_rank=max_rank)
-                pipeline.load_lora_weights(file_name)
-                # optionally compile the model now
-                ```
-
-                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
-                limitations to this technique, which are documented here:
-                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
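The detailed hotswap docstring removed above is now a cross-reference, but the workflow it described is unchanged: reserve capacity with `enable_lora_hotswap` before compiling, load a first adapter, then swap new weights in under the same adapter name. A sketch of that flow, where the file names, the target rank, and the pipeline construction are placeholders:

```py
# Sketch of the hotswap workflow summarized from the removed docstring.
pipe = ...  # an already-constructed LoRA-capable pipeline, e.g. a HiDreamImagePipeline

# Call *before* compiling and before loading adapters so adapters with
# different ranks/alphas can be accommodated later; 64 is an assumed max rank.
pipe.enable_lora_hotswap(target_rank=64)

pipe.load_lora_weights("lora_a.safetensors", adapter_name="default_0")
# optionally compile the transformer now (e.g. with torch.compile)

# Replace the weights of the already-loaded adapter in place, avoiding recompilation.
pipe.load_lora_weights("lora_b.safetensors", adapter_name="default_0", hotswap=True)
```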
@@ -5273,6 +5998,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             state_dict,
             network_alphas=None,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -5288,9 +6014,10 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
-        Save the LoRA parameters corresponding to the UNet and text encoder.
+        Save the LoRA parameters corresponding to the transformer.

         Arguments:
             save_directory (`str` or `os.PathLike`):
@@ -5307,14 +6034,21 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            transformer_lora_adapter_metadata:
+                LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-        if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        if transformer_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
+            )

         # Save the model
         cls.write_lora_layers(
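On the saving side, the new `transformer_lora_adapter_metadata` argument is packed under the transformer prefix and written next to the weights, so `lora_state_dict(..., return_lora_metadata=True)` can recover it later. A hedged sketch of passing such metadata; the layer dict, output directory, and the exact metadata fields (shown here as typical `peft` `LoraConfig`-style entries) are assumptions for illustration:

```py
# Sketch only: `trained_lora_layers` stands for the transformer LoRA layer dict
# produced by training (e.g. extracted with peft utilities).
HiDreamImagePipeline.save_lora_weights(
    save_directory="./hidream-lora",                  # placeholder output directory
    transformer_lora_layers=trained_lora_layers,
    transformer_lora_adapter_metadata={               # assumed LoraConfig-style fields
        "r": 16,
        "lora_alpha": 16,
        "target_modules": ["to_q", "to_k", "to_v", "to_out.0"],
    },
)
```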
@@ -5324,9 +6058,10 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -5374,7 +6109,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             **kwargs,
         )

-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of