diffusers 0.33.1 (py3-none-any.whl) → 0.35.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -33,6 +33,24 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
33
33
  # 1. get all state_dict_keys
34
34
  all_keys = list(state_dict.keys())
35
35
  sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
36
+ not_sgm_patterns = ["down_blocks", "mid_block", "up_blocks"]
37
+
38
+ # check if state_dict contains both patterns
39
+ contains_sgm_patterns = False
40
+ contains_not_sgm_patterns = False
41
+ for key in all_keys:
42
+ if any(p in key for p in sgm_patterns):
43
+ contains_sgm_patterns = True
44
+ elif any(p in key for p in not_sgm_patterns):
45
+ contains_not_sgm_patterns = True
46
+
47
+ # if state_dict contains both patterns, remove sgm
48
+ # we can then return state_dict immediately
49
+ if contains_sgm_patterns and contains_not_sgm_patterns:
50
+ for key in all_keys:
51
+ if any(p in key for p in sgm_patterns):
52
+ state_dict.pop(key)
53
+ return state_dict
36
54
 
37
55
  # 2. check if needs remapping, if not return original dict
38
56
  is_in_sgm_format = False
@@ -126,7 +144,7 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
126
144
  )
127
145
  new_state_dict[new_key] = state_dict.pop(key)
128
146
 
129
- if len(state_dict) > 0:
147
+ if state_dict:
130
148
  raise ValueError("At this point all state dict entries have to be converted.")
131
149
 
132
150
  return new_state_dict
@@ -415,7 +433,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
415
433
  ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
416
434
  if not is_sparse:
417
435
  # down_weight is copied to each split
418
- ait_sd.update({k: down_weight for k in ait_down_keys})
436
+ ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
419
437
 
420
438
  # up_weight is split to each split
421
439
  ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
@@ -709,8 +727,25 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
709
727
  elif k.startswith("lora_te1_"):
710
728
  has_te_keys = True
711
729
  continue
730
+ elif k.startswith("lora_transformer_context_embedder"):
731
+ diffusers_key = "context_embedder"
732
+ elif k.startswith("lora_transformer_norm_out_linear"):
733
+ diffusers_key = "norm_out.linear"
734
+ elif k.startswith("lora_transformer_proj_out"):
735
+ diffusers_key = "proj_out"
736
+ elif k.startswith("lora_transformer_x_embedder"):
737
+ diffusers_key = "x_embedder"
738
+ elif k.startswith("lora_transformer_time_text_embed_guidance_embedder_linear_"):
739
+ i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1])
740
+ diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}"
741
+ elif k.startswith("lora_transformer_time_text_embed_text_embedder_linear_"):
742
+ i = int(k.split("lora_transformer_time_text_embed_text_embedder_linear_")[-1])
743
+ diffusers_key = f"time_text_embed.text_embedder.linear_{i}"
744
+ elif k.startswith("lora_transformer_time_text_embed_timestep_embedder_linear_"):
745
+ i = int(k.split("lora_transformer_time_text_embed_timestep_embedder_linear_")[-1])
746
+ diffusers_key = f"time_text_embed.timestep_embedder.linear_{i}"
712
747
  else:
713
- raise NotImplementedError
748
+ raise NotImplementedError(f"Handling for key ({k}) is not implemented.")
714
749
 
715
750
  if "attn_" in k:
716
751
  if "_to_out_0" in k:
@@ -782,7 +817,11 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
782
817
  # has both `peft` and non-peft state dict.
783
818
  has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
784
819
  if has_peft_state_dict:
785
- state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
820
+ state_dict = {
821
+ k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
822
+ for k, v in state_dict.items()
823
+ if k.startswith("transformer.")
824
+ }
786
825
  return state_dict
787
826
 
788
827
  # Another weird one.
@@ -801,7 +840,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
801
840
  if zero_status_pe:
802
841
  logger.info(
803
842
  "The `position_embedding` LoRA params are all zeros which make them ineffective. "
804
- "So, we will purge them out of the curret state dict to make loading possible."
843
+ "So, we will purge them out of the current state dict to make loading possible."
805
844
  )
806
845
 
807
846
  else:
@@ -817,7 +856,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
817
856
  if zero_status_t5:
818
857
  logger.info(
819
858
  "The `t5xxl` LoRA params are all zeros which make them ineffective. "
820
- "So, we will purge them out of the curret state dict to make loading possible."
859
+ "So, we will purge them out of the current state dict to make loading possible."
821
860
  )
822
861
  else:
823
862
  logger.info(
@@ -832,7 +871,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
832
871
  if zero_status_diff_b:
833
872
  logger.info(
834
873
  "The `diff_b` LoRA params are all zeros which make them ineffective. "
835
- "So, we will purge them out of the curret state dict to make loading possible."
874
+ "So, we will purge them out of the current state dict to make loading possible."
836
875
  )
837
876
  else:
838
877
  logger.info(
@@ -848,7 +887,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
848
887
  if zero_status_diff:
849
888
  logger.info(
850
889
  "The `diff` LoRA params are all zeros which make them ineffective. "
851
- "So, we will purge them out of the curret state dict to make loading possible."
890
+ "So, we will purge them out of the current state dict to make loading possible."
852
891
  )
853
892
  else:
854
893
  logger.info(
@@ -905,7 +944,7 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
905
944
  ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
906
945
 
907
946
  # down_weight is copied to each split
908
- ait_sd.update({k: down_weight for k in ait_down_keys})
947
+ ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
909
948
 
910
949
  # up_weight is split to each split
911
950
  ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
@@ -1219,7 +1258,7 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
1219
1258
  f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
1220
1259
  )
1221
1260
 
1222
- # single transfomer blocks
1261
+ # single transformer blocks
1223
1262
  for i in range(num_single_layers):
1224
1263
  block_prefix = f"single_transformer_blocks.{i}."
1225
1264
 
@@ -1311,6 +1350,228 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
1311
1350
  return converted_state_dict
1312
1351
 
1313
1352
 
1353
+ def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
1354
+ converted_state_dict = {}
1355
+ original_state_dict_keys = list(original_state_dict.keys())
1356
+ num_layers = 19
1357
+ num_single_layers = 38
1358
+ inner_dim = 3072
1359
+ mlp_ratio = 4.0
1360
+
1361
+ # double transformer blocks
1362
+ for i in range(num_layers):
1363
+ block_prefix = f"transformer_blocks.{i}."
1364
+ original_block_prefix = "base_model.model."
1365
+
1366
+ for lora_key in ["lora_A", "lora_B"]:
1367
+ # norms
1368
+ converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
1369
+ f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
1370
+ )
1371
+ if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
1372
+ converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
1373
+ f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
1374
+ )
1375
+
1376
+ converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
1377
+ f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
1378
+ )
1379
+
1380
+ # Q, K, V
1381
+ if lora_key == "lora_A":
1382
+ sample_lora_weight = original_state_dict.pop(
1383
+ f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
1384
+ )
1385
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
1386
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
1387
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
1388
+
1389
+ context_lora_weight = original_state_dict.pop(
1390
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
1391
+ )
1392
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
1393
+ [context_lora_weight]
1394
+ )
1395
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
1396
+ [context_lora_weight]
1397
+ )
1398
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
1399
+ [context_lora_weight]
1400
+ )
1401
+ else:
1402
+ sample_q, sample_k, sample_v = torch.chunk(
1403
+ original_state_dict.pop(
1404
+ f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
1405
+ ),
1406
+ 3,
1407
+ dim=0,
1408
+ )
1409
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
1410
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
1411
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
1412
+
1413
+ context_q, context_k, context_v = torch.chunk(
1414
+ original_state_dict.pop(
1415
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
1416
+ ),
1417
+ 3,
1418
+ dim=0,
1419
+ )
1420
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
1421
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
1422
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
1423
+
1424
+ if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
1425
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
1426
+ original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
1427
+ 3,
1428
+ dim=0,
1429
+ )
1430
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
1431
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
1432
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
1433
+
1434
+ if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
1435
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
1436
+ original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
1437
+ 3,
1438
+ dim=0,
1439
+ )
1440
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
1441
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
1442
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
1443
+
1444
+ # ff img_mlp
1445
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
1446
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
1447
+ )
1448
+ if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
1449
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
1450
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
1451
+ )
1452
+
1453
+ converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
1454
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
1455
+ )
1456
+ if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
1457
+ converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
1458
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
1459
+ )
1460
+
1461
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
1462
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
1463
+ )
1464
+ if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
1465
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
1466
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
1467
+ )
1468
+
1469
+ converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
1470
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
1471
+ )
1472
+ if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
1473
+ converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
1474
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
1475
+ )
1476
+
1477
+ # output projections.
1478
+ converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
1479
+ f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
1480
+ )
1481
+ if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
1482
+ converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
1483
+ f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
1484
+ )
1485
+ converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
1486
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
1487
+ )
1488
+ if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
1489
+ converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
1490
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
1491
+ )
1492
+
1493
+ # single transformer blocks
1494
+ for i in range(num_single_layers):
1495
+ block_prefix = f"single_transformer_blocks.{i}."
1496
+
1497
+ for lora_key in ["lora_A", "lora_B"]:
1498
+ # norm.linear <- single_blocks.0.modulation.lin
1499
+ converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
1500
+ f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
1501
+ )
1502
+ if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
1503
+ converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
1504
+ f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
1505
+ )
1506
+
1507
+ # Q, K, V, mlp
1508
+ mlp_hidden_dim = int(inner_dim * mlp_ratio)
1509
+ split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
1510
+
1511
+ if lora_key == "lora_A":
1512
+ lora_weight = original_state_dict.pop(
1513
+ f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
1514
+ )
1515
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
1516
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
1517
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
1518
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
1519
+
1520
+ if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
1521
+ lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
1522
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
1523
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
1524
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
1525
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
1526
+ else:
1527
+ q, k, v, mlp = torch.split(
1528
+ original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
1529
+ split_size,
1530
+ dim=0,
1531
+ )
1532
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
1533
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
1534
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
1535
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
1536
+
1537
+ if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
1538
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(
1539
+ original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
1540
+ split_size,
1541
+ dim=0,
1542
+ )
1543
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
1544
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
1545
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
1546
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
1547
+
1548
+ # output projections.
1549
+ converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
1550
+ f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
1551
+ )
1552
+ if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
1553
+ converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
1554
+ f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
1555
+ )
1556
+
1557
+ for lora_key in ["lora_A", "lora_B"]:
1558
+ converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
1559
+ f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
1560
+ )
1561
+ if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
1562
+ converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
1563
+ f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
1564
+ )
1565
+
1566
+ if len(original_state_dict) > 0:
1567
+ raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
1568
+
1569
+ for key in list(converted_state_dict.keys()):
1570
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
1571
+
1572
+ return converted_state_dict
1573
+
1574
+
1314
1575
  def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
1315
1576
  converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
1316
1577
 
@@ -1561,45 +1822,286 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
1561
1822
  converted_state_dict = {}
1562
1823
  original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
1563
1824
 
1564
- num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
1825
+ block_numbers = {int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")}
1826
+ min_block = min(block_numbers)
1827
+ max_block = max(block_numbers)
1828
+
1565
1829
  is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
1830
+ lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
1831
+ lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
1832
+ has_time_projection_weight = any(
1833
+ k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
1834
+ )
1566
1835
 
1567
- for i in range(num_blocks):
1836
+ def get_alpha_scales(down_weight, alpha_key):
1837
+ rank = down_weight.shape[0]
1838
+ alpha = original_state_dict.pop(alpha_key).item()
1839
+ scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
1840
+ scale_down = scale
1841
+ scale_up = 1.0
1842
+ while scale_down * 2 < scale_up:
1843
+ scale_down *= 2
1844
+ scale_up /= 2
1845
+ return scale_down, scale_up
1846
+
1847
+ for key in list(original_state_dict.keys()):
1848
+ if key.endswith((".diff", ".diff_b")) and "norm" in key:
1849
+ # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
1850
+ # in future if needed and they are not zeroed.
1851
+ original_state_dict.pop(key)
1852
+ logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
1853
+
1854
+ if "time_projection" in key and not has_time_projection_weight:
1855
+ # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
1856
+ # our lora config adds the time proj lora layers, but we don't have the weights for them.
1857
+ # CausVid lora has the weight keys and the bias keys.
1858
+ original_state_dict.pop(key)
1859
+
1860
+ # For the `diff_b` keys, we treat them as lora_bias.
1861
+ # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
1862
+
1863
+ for i in range(min_block, max_block + 1):
1568
1864
  # Self-attention
1569
1865
  for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
1570
- converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
1571
- f"blocks.{i}.self_attn.{o}.lora_A.weight"
1572
- )
1573
- converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
1574
- f"blocks.{i}.self_attn.{o}.lora_B.weight"
1575
- )
1866
+ alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
1867
+ has_alpha = alpha_key in original_state_dict
1868
+ original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
1869
+ converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"
1870
+
1871
+ original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
1872
+ converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
1873
+
1874
+ if has_alpha:
1875
+ down_weight = original_state_dict.pop(original_key_A)
1876
+ up_weight = original_state_dict.pop(original_key_B)
1877
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
1878
+ converted_state_dict[converted_key_A] = down_weight * scale_down
1879
+ converted_state_dict[converted_key_B] = up_weight * scale_up
1880
+
1881
+ else:
1882
+ if original_key_A in original_state_dict:
1883
+ converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
1884
+ if original_key_B in original_state_dict:
1885
+ converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
1886
+
1887
+ original_key = f"blocks.{i}.self_attn.{o}.diff_b"
1888
+ converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
1889
+ if original_key in original_state_dict:
1890
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1576
1891
 
1577
1892
  # Cross-attention
1578
1893
  for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
1579
- converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
1580
- f"blocks.{i}.cross_attn.{o}.lora_A.weight"
1581
- )
1582
- converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
1583
- f"blocks.{i}.cross_attn.{o}.lora_B.weight"
1584
- )
1894
+ alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
1895
+ has_alpha = alpha_key in original_state_dict
1896
+ original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
1897
+ converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
1898
+
1899
+ original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
1900
+ converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
1901
+
1902
+ if original_key_A in original_state_dict:
1903
+ down_weight = original_state_dict.pop(original_key_A)
1904
+ converted_state_dict[converted_key_A] = down_weight
1905
+ if original_key_B in original_state_dict:
1906
+ up_weight = original_state_dict.pop(original_key_B)
1907
+ converted_state_dict[converted_key_B] = up_weight
1908
+ if has_alpha:
1909
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
1910
+ converted_state_dict[converted_key_A] *= scale_down
1911
+ converted_state_dict[converted_key_B] *= scale_up
1912
+
1913
+ original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
1914
+ converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
1915
+ if original_key in original_state_dict:
1916
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1585
1917
 
1586
1918
  if is_i2v_lora:
1587
1919
  for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
1588
- converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
1589
- f"blocks.{i}.cross_attn.{o}.lora_A.weight"
1590
- )
1591
- converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
1592
- f"blocks.{i}.cross_attn.{o}.lora_B.weight"
1593
- )
1920
+ alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
1921
+ has_alpha = alpha_key in original_state_dict
1922
+ original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
1923
+ converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
1924
+
1925
+ original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
1926
+ converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
1927
+
1928
+ if original_key_A in original_state_dict:
1929
+ down_weight = original_state_dict.pop(original_key_A)
1930
+ converted_state_dict[converted_key_A] = down_weight
1931
+ if original_key_B in original_state_dict:
1932
+ up_weight = original_state_dict.pop(original_key_B)
1933
+ converted_state_dict[converted_key_B] = up_weight
1934
+ if has_alpha:
1935
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
1936
+ converted_state_dict[converted_key_A] *= scale_down
1937
+ converted_state_dict[converted_key_B] *= scale_up
1938
+
1939
+ original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
1940
+ converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
1941
+ if original_key in original_state_dict:
1942
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1594
1943
 
1595
1944
  # FFN
1596
1945
  for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
1597
- converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
1598
- f"blocks.{i}.{o}.lora_A.weight"
1946
+ alpha_key = f"blocks.{i}.{o}.alpha"
1947
+ has_alpha = alpha_key in original_state_dict
1948
+ original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
1949
+ converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
1950
+
1951
+ original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
1952
+ converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
1953
+
1954
+ if original_key_A in original_state_dict:
1955
+ down_weight = original_state_dict.pop(original_key_A)
1956
+ converted_state_dict[converted_key_A] = down_weight
1957
+ if original_key_B in original_state_dict:
1958
+ up_weight = original_state_dict.pop(original_key_B)
1959
+ converted_state_dict[converted_key_B] = up_weight
1960
+ if has_alpha:
1961
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
1962
+ converted_state_dict[converted_key_A] *= scale_down
1963
+ converted_state_dict[converted_key_B] *= scale_up
1964
+
1965
+ original_key = f"blocks.{i}.{o}.diff_b"
1966
+ converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
1967
+ if original_key in original_state_dict:
1968
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1969
+
1970
+ # Remaining.
1971
+ if original_state_dict:
1972
+ if any("time_projection" in k for k in original_state_dict):
1973
+ original_key = f"time_projection.1.{lora_down_key}.weight"
1974
+ converted_key = "condition_embedder.time_proj.lora_A.weight"
1975
+ if original_key in original_state_dict:
1976
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1977
+
1978
+ original_key = f"time_projection.1.{lora_up_key}.weight"
1979
+ converted_key = "condition_embedder.time_proj.lora_B.weight"
1980
+ if original_key in original_state_dict:
1981
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
1982
+
1983
+ if "time_projection.1.diff_b" in original_state_dict:
1984
+ converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
1985
+ "time_projection.1.diff_b"
1986
+ )
1987
+
1988
+ if any("head.head" in k for k in state_dict):
1989
+ converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
1990
+ f"head.head.{lora_down_key}.weight"
1599
1991
  )
1600
- converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
1601
- f"blocks.{i}.{o}.lora_B.weight"
1992
+ converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
1993
+ if "head.head.diff_b" in original_state_dict:
1994
+ converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
1995
+
1996
+ for text_time in ["text_embedding", "time_embedding"]:
1997
+ if any(text_time in k for k in original_state_dict):
1998
+ for b_n in [0, 2]:
1999
+ diffusers_b_n = 1 if b_n == 0 else 2
2000
+ diffusers_name = (
2001
+ "condition_embedder.text_embedder"
2002
+ if text_time == "text_embedding"
2003
+ else "condition_embedder.time_embedder"
2004
+ )
2005
+ if any(f"{text_time}.{b_n}" in k for k in original_state_dict):
2006
+ converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_A.weight"] = (
2007
+ original_state_dict.pop(f"{text_time}.{b_n}.{lora_down_key}.weight")
2008
+ )
2009
+ converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.weight"] = (
2010
+ original_state_dict.pop(f"{text_time}.{b_n}.{lora_up_key}.weight")
2011
+ )
2012
+ if f"{text_time}.{b_n}.diff_b" in original_state_dict:
2013
+ converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.bias"] = (
2014
+ original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
2015
+ )
2016
+
2017
+ for img_ours, img_theirs in [
2018
+ ("ff.net.0.proj", "img_emb.proj.1"),
2019
+ ("ff.net.2", "img_emb.proj.3"),
2020
+ ]:
2021
+ original_key = f"{img_theirs}.{lora_down_key}.weight"
2022
+ converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_A.weight"
2023
+ if original_key in original_state_dict:
2024
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
2025
+
2026
+ original_key = f"{img_theirs}.{lora_up_key}.weight"
2027
+ converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
2028
+ if original_key in original_state_dict:
2029
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
2030
+ bias_key_theirs = original_key.removesuffix(f".{lora_up_key}.weight") + ".diff_b"
2031
+ if bias_key_theirs in original_state_dict:
2032
+ bias_key = converted_key.removesuffix(".weight") + ".bias"
2033
+ converted_state_dict[bias_key] = original_state_dict.pop(bias_key_theirs)
2034
+
2035
+ if len(original_state_dict) > 0:
2036
+ diff = all(".diff" in k for k in original_state_dict)
2037
+ if diff:
2038
+ diff_keys = {k for k in original_state_dict if k.endswith(".diff")}
2039
+ if not all("lora" not in k for k in diff_keys):
2040
+ raise ValueError
2041
+ logger.info(
2042
+ "The remaining `state_dict` contains `diff` keys which we do not handle yet. If you see performance issues, please file an issue: "
2043
+ "https://github.com/huggingface/diffusers//issues/new"
1602
2044
  )
2045
+ else:
2046
+ raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
2047
+
2048
+ for key in list(converted_state_dict.keys()):
2049
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
2050
+
2051
+ return converted_state_dict
2052
+
2053
+
2054
+ def _convert_musubi_wan_lora_to_diffusers(state_dict):
2055
+ # https://github.com/kohya-ss/musubi-tuner
2056
+ converted_state_dict = {}
2057
+ original_state_dict = {k[len("lora_unet_") :]: v for k, v in state_dict.items()}
2058
+
2059
+ num_blocks = len({k.split("blocks_")[1].split("_")[0] for k in original_state_dict})
2060
+ is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
2061
+
2062
+ def get_alpha_scales(down_weight, key):
2063
+ rank = down_weight.shape[0]
2064
+ alpha = original_state_dict.pop(key + ".alpha").item()
2065
+ scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
2066
+ scale_down = scale
2067
+ scale_up = 1.0
2068
+ while scale_down * 2 < scale_up:
2069
+ scale_down *= 2
2070
+ scale_up /= 2
2071
+ return scale_down, scale_up
2072
+
2073
+ for i in range(num_blocks):
2074
+ # Self-attention
2075
+ for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
2076
+ down_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_down.weight")
2077
+ up_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_up.weight")
2078
+ scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_self_attn_{o}")
2079
+ converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = down_weight * scale_down
2080
+ converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = up_weight * scale_up
2081
+
2082
+ # Cross-attention
2083
+ for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
2084
+ down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
2085
+ up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
2086
+ scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
2087
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
2088
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
2089
+
2090
+ if is_i2v_lora:
2091
+ for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
2092
+ down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
2093
+ up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
2094
+ scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
2095
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
2096
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
2097
+
2098
+ # FFN
2099
+ for o, c in zip(["ffn_0", "ffn_2"], ["net.0.proj", "net.2"]):
2100
+ down_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_down.weight")
2101
+ up_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_up.weight")
2102
+ scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_{o}")
2103
+ converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = down_weight * scale_down
2104
+ converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = up_weight * scale_up
1603
2105
 
1604
2106
  if len(original_state_dict) > 0:
1605
2107
  raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
@@ -1608,3 +2110,123 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
1608
2110
  converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
1609
2111
 
1610
2112
  return converted_state_dict
2113
+
2114
+
2115
+ def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
2116
+ if not all(k.startswith(non_diffusers_prefix) for k in state_dict):
2117
+ raise ValueError("Invalid LoRA state dict for HiDream.")
2118
+ converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
2119
+ converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
2120
+ return converted_state_dict
2121
+
2122
+
2123
+ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
2124
+ if not all(k.startswith(f"{non_diffusers_prefix}.") for k in state_dict):
2125
+ raise ValueError("Invalid LoRA state dict for LTX-Video.")
2126
+ converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
2127
+ converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
2128
+ return converted_state_dict
2129
+
2130
+
2131
+ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
2132
+ has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
2133
+ if has_lora_unet:
2134
+ state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
2135
+
2136
+ def convert_key(key: str) -> str:
2137
+ prefix = "transformer_blocks"
2138
+ if "." in key:
2139
+ base, suffix = key.rsplit(".", 1)
2140
+ else:
2141
+ base, suffix = key, ""
2142
+
2143
+ start = f"{prefix}_"
2144
+ rest = base[len(start) :]
2145
+
2146
+ if "." in rest:
2147
+ head, tail = rest.split(".", 1)
2148
+ tail = "." + tail
2149
+ else:
2150
+ head, tail = rest, ""
2151
+
2152
+ # Protected n-grams that must keep their internal underscores
2153
+ protected = {
2154
+ # pairs
2155
+ ("to", "q"),
2156
+ ("to", "k"),
2157
+ ("to", "v"),
2158
+ ("to", "out"),
2159
+ ("add", "q"),
2160
+ ("add", "k"),
2161
+ ("add", "v"),
2162
+ ("txt", "mlp"),
2163
+ ("img", "mlp"),
2164
+ ("txt", "mod"),
2165
+ ("img", "mod"),
2166
+ # triplets
2167
+ ("add", "q", "proj"),
2168
+ ("add", "k", "proj"),
2169
+ ("add", "v", "proj"),
2170
+ ("to", "add", "out"),
2171
+ }
2172
+
2173
+ prot_by_len = {}
2174
+ for ng in protected:
2175
+ prot_by_len.setdefault(len(ng), set()).add(ng)
2176
+
2177
+ parts = head.split("_")
2178
+ merged = []
2179
+ i = 0
2180
+ lengths_desc = sorted(prot_by_len.keys(), reverse=True)
2181
+
2182
+ while i < len(parts):
2183
+ matched = False
2184
+ for L in lengths_desc:
2185
+ if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
2186
+ merged.append("_".join(parts[i : i + L]))
2187
+ i += L
2188
+ matched = True
2189
+ break
2190
+ if not matched:
2191
+ merged.append(parts[i])
2192
+ i += 1
2193
+
2194
+ head_converted = ".".join(merged)
2195
+ converted_base = f"{prefix}.{head_converted}{tail}"
2196
+ return converted_base + (("." + suffix) if suffix else "")
2197
+
2198
+ state_dict = {convert_key(k): v for k, v in state_dict.items()}
2199
+
2200
+ converted_state_dict = {}
2201
+ all_keys = list(state_dict.keys())
2202
+ down_key = ".lora_down.weight"
2203
+ up_key = ".lora_up.weight"
2204
+
2205
+ def get_alpha_scales(down_weight, alpha_key):
2206
+ rank = down_weight.shape[0]
2207
+ alpha = state_dict.pop(alpha_key).item()
2208
+ scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
2209
+ scale_down = scale
2210
+ scale_up = 1.0
2211
+ while scale_down * 2 < scale_up:
2212
+ scale_down *= 2
2213
+ scale_up /= 2
2214
+ return scale_down, scale_up
2215
+
2216
+ for k in all_keys:
2217
+ if k.endswith(down_key):
2218
+ diffusers_down_key = k.replace(down_key, ".lora_A.weight")
2219
+ diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
2220
+ alpha_key = k.replace(down_key, ".alpha")
2221
+
2222
+ down_weight = state_dict.pop(k)
2223
+ up_weight = state_dict.pop(k.replace(down_key, up_key))
2224
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
2225
+ converted_state_dict[diffusers_down_key] = down_weight * scale_down
2226
+ converted_state_dict[diffusers_up_key] = up_weight * scale_up
2227
+
2228
+ if len(state_dict) > 0:
2229
+ raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
2230
+
2231
+ converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
2232
+ return converted_state_dict