diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +595 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
  384. diffusers-0.33.1.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
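The headline changes are visible from the file list alone: a new diffusers/hooks package (group offloading, layerwise casting, FasterCache, pyramid attention broadcast), new model families (CogView4, ConsisID, EasyAnimate, Lumina2, OmniGen, SANA-Sprint, Wan), a quanto quantization backend, remote-utils helpers, and an AutoModel entry point (diffusers/models/auto_model.py). As a minimal sketch of what the new auto_model.py enables, assuming the AutoModel.from_pretrained API matches the 0.33 release notes (the repo id is illustrative, not taken from this diff):

    import torch
    from diffusers import AutoModel

    # AutoModel resolves the concrete model class (here a UNet2DConditionModel)
    # from the checkpoint's config, so callers do not have to name it up front.
    unet = AutoModel.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
        subfolder="unet",
        torch_dtype=torch.float16,
    )

The largest single diff, diffusers/loaders/single_file_utils.py (+646 -72), is reproduced below.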
diffusers/loaders/single_file_utils.py

@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -44,6 +44,7 @@ from ..utils import (
     is_transformers_available,
     logging,
 )
+from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
 from ..utils.hub_utils import _get_model_file


@@ -94,6 +95,12 @@ CHECKPOINT_KEY_NAMES = {
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
     "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
     "animatediff_rgb": "controlnet_cond_embedding.weight",
+    "auraflow": [
+        "double_layers.0.attn.w2q.weight",
+        "double_layers.0.attn.w1q.weight",
+        "cond_seq_linear.weight",
+        "t_embedder.mlp.0.weight",
+    ],
     "flux": [
         "double_blocks.0.img_attn.norm.key_norm.scale",
         "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -109,6 +116,16 @@ CHECKPOINT_KEY_NAMES = {
     "autoencoder-dc-sana": "encoder.project_in.conv.bias",
     "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
+    "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
+    "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
+    "sana": [
+        "blocks.0.cross_attn.q_linear.weight",
+        "blocks.0.cross_attn.q_linear.bias",
+        "blocks.0.cross_attn.kv_linear.weight",
+        "blocks.0.cross_attn.kv_linear.bias",
+    ],
+    "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
+    "wan_vae": "decoder.middle.0.residual.0.gamma",
 }

 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -153,6 +170,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
     "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
     "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
+    "auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
     "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
     "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
     "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
@@ -165,6 +183,12 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
     "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
+    "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
+    "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
+    "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
+    "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
+    "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
+    "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
 }

 # Use to configure model sample size when original config is provided
@@ -177,6 +201,7 @@ DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = {
     "inpainting": 512,
     "inpainting_v2": 512,
     "controlnet": 512,
+    "instruct-pix2pix": 512,
     "v2": 768,
     "v1": 512,
 }
@@ -378,12 +403,14 @@ def load_single_file_checkpoint(
     cache_dir=None,
     local_files_only=None,
     revision=None,
+    disable_mmap=False,
 ):
     if os.path.isfile(pretrained_model_link_or_path):
         pretrained_model_link_or_path = pretrained_model_link_or_path

     else:
         repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
+        user_agent = {"file_type": "single_file", "framework": "pytorch"}
         pretrained_model_link_or_path = _get_model_file(
             repo_id,
             weights_name=weights_name,
@@ -393,9 +420,10 @@ def load_single_file_checkpoint(
             local_files_only=local_files_only,
             token=token,
             revision=revision,
+            user_agent=user_agent,
         )

-    checkpoint = load_state_dict(pretrained_model_link_or_path)
+    checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)

     # some checkpoints contain the model state dict under a "state_dict" key
     while "state_dict" in checkpoint:
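The new disable_mmap flag is threaded from load_single_file_checkpoint into load_state_dict, letting callers force a full read of a safetensors file instead of memory-mapping it (useful on network filesystems where mmap access patterns are slow). A sketch of how this surfaces through the public single-file loader, assuming from_single_file forwards the flag as documented for 0.33 (the checkpoint URL is illustrative):

    import torch
    from diffusers import FluxTransformer2DModel

    # disable_mmap=True reads the whole checkpoint into memory up front
    # rather than paging it in lazily via mmap.
    transformer = FluxTransformer2DModel.from_single_file(
        "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors",
        torch_dtype=torch.bfloat16,
        disable_mmap=True,
    )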
@@ -416,7 +444,7 @@ def fetch_original_config(original_config_file, local_files_only=False):
                 "Please provide a valid local file path."
             )

-        original_config_file = BytesIO(requests.get(original_config_file).content)
+        original_config_file = BytesIO(requests.get(original_config_file, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)

     else:
         raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
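Same change as the import hunk earlier: the bare requests.get now carries a bounded timeout, so a stalled download raises requests.exceptions.Timeout instead of hanging indefinitely. The same pattern sketched standalone (the URL is illustrative; the constant's value is defined in diffusers/utils/constants.py and not shown in this diff):

    from io import BytesIO

    import requests
    from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT

    # A stalled server now fails fast with requests.exceptions.Timeout.
    resp = requests.get("https://example.com/v1-inference.yaml", timeout=DIFFUSERS_REQUEST_TIMEOUT)
    original_config_file = BytesIO(resp.content)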
@@ -637,6 +665,36 @@ def infer_diffusers_model_type(checkpoint):
     elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
         model_type = "hunyuan-video"

+    elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
+        model_type = "auraflow"
+
+    elif (
+        CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
+    ):
+        model_type = "instruct-pix2pix"
+
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
+        model_type = "lumina2"
+
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
+        model_type = "sana"
+
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
+        if "model.diffusion_model.patch_embedding.weight" in checkpoint:
+            target_key = "model.diffusion_model.patch_embedding.weight"
+        else:
+            target_key = "patch_embedding.weight"
+
+        if checkpoint[target_key].shape[0] == 1536:
+            model_type = "wan-t2v-1.3B"
+        elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
+            model_type = "wan-t2v-14B"
+        else:
+            model_type = "wan-i2v-14B"
+    elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
+        # All Wan models use the same VAE so we can use the same default model repo to fetch the config
+        model_type = "wan-t2v-14B"
     else:
         model_type = "v1"

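All of the new branches probe only key names and tensor shapes, so they can be exercised on a dict of dummy tensors. A minimal sketch against the internal helper (infer_diffusers_model_type is private API; only the probed key and the first dimension of patch_embedding.weight matter here, the other dimensions are placeholders):

    import torch
    from diffusers.loaders.single_file_utils import infer_diffusers_model_type

    # Looks like a Wan checkpoint: "head.modulation" selects the wan branch,
    # and a 1536-row patch embedding routes to the 1.3B text-to-video config.
    fake_wan = {
        "head.modulation": torch.zeros(1, 6, 1536),
        "patch_embedding.weight": torch.zeros(1536, 16, 1, 2, 2),
    }
    print(infer_diffusers_model_type(fake_wan))  # -> "wan-t2v-1.3B"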
@@ -1423,8 +1481,8 @@

     if text_proj_key in checkpoint:
         text_proj_dim = int(checkpoint[text_proj_key].shape[0])
-    elif hasattr(text_model.config, "projection_dim"):
-        text_proj_dim = text_model.config.projection_dim
+    elif hasattr(text_model.config, "hidden_size"):
+        text_proj_dim = text_model.config.hidden_size
     else:
         text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM

@@ -1568,18 +1626,9 @@ def create_diffusers_clip_model_from_ldm(
         raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")

     if is_accelerate_available():
-        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
     else:
-        _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
-
-    if model._keys_to_ignore_on_load_unexpected is not None:
-        for pat in model._keys_to_ignore_on_load_unexpected:
-            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-    if len(unexpected_keys) > 0:
-        logger.warning(
-            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-        )
+        model.load_state_dict(diffusers_format_checkpoint, strict=False)

     if torch_dtype is not None:
         model.to(torch_dtype)
@@ -2036,16 +2085,7 @@ def create_diffusers_t5_model_from_checkpoint(
     diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)

     if is_accelerate_available():
-        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        if model._keys_to_ignore_on_load_unexpected is not None:
-            for pat in model._keys_to_ignore_on_load_unexpected:
-                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-            )
-
+        load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
     else:
         model.load_state_dict(diffusers_format_checkpoint)

@@ -2086,6 +2126,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
 def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
+
     for k in keys:
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
@@ -2366,7 +2407,6 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
         "per_channel_statistics.channel": remove_keys_,
         "per_channel_statistics.mean-of-means": remove_keys_,
         "per_channel_statistics.mean-of-stds": remove_keys_,
-        "timestep_scale_multiplier": remove_keys_,
     }

     if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
@@ -2460,7 +2500,7 @@ def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):


 def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
-    new_state_dict = {}
+    converted_state_dict = {}

     # Comfy checkpoints add this prefix
     keys = list(checkpoint.keys())
@@ -2469,22 +2509,22 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

     # Convert patch_embed
-    new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
-    new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+    converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")

     # Convert time_embed
-    new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
-    new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
-    new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
-    new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
-    new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
-    new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
-    new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
-    new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
-    new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
-    new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
-    new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
-    new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
+    converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+    converted_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
+    converted_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
+    converted_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
+    converted_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
+    converted_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
+    converted_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
+    converted_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
+    converted_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")

     # Convert transformer blocks
     num_layers = 48
@@ -2493,68 +2533,84 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         old_prefix = f"blocks.{i}."

         # norm1
-        new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
-        new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
+        converted_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
+        converted_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
         if i < num_layers - 1:
-            new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
-            new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+            converted_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(
+                old_prefix + "mod_y.weight"
+            )
+            converted_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(
+                old_prefix + "mod_y.bias"
+            )
         else:
-            new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
+            converted_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
                 old_prefix + "mod_y.weight"
             )
-            new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+            converted_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(
+                old_prefix + "mod_y.bias"
+            )

         # Visual attention
         qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
         q, k, v = qkv_weight.chunk(3, dim=0)

-        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
-        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
-        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
-        new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
-        new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
-        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
-        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
+        converted_state_dict[block_prefix + "attn1.to_q.weight"] = q
+        converted_state_dict[block_prefix + "attn1.to_k.weight"] = k
+        converted_state_dict[block_prefix + "attn1.to_v.weight"] = v
+        converted_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(
+            old_prefix + "attn.q_norm_x.weight"
+        )
+        converted_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(
+            old_prefix + "attn.k_norm_x.weight"
+        )
+        converted_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(
+            old_prefix + "attn.proj_x.weight"
+        )
+        converted_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")

         # Context attention
         qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
         q, k, v = qkv_weight.chunk(3, dim=0)

-        new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
-        new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
-        new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
-        new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
+        converted_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
+        converted_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
+        converted_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
+        converted_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
             old_prefix + "attn.q_norm_y.weight"
         )
-        new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
+        converted_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
             old_prefix + "attn.k_norm_y.weight"
         )
         if i < num_layers - 1:
-            new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
+            converted_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
                 old_prefix + "attn.proj_y.weight"
             )
-            new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
+            converted_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(
+                old_prefix + "attn.proj_y.bias"
+            )

         # MLP
-        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
+        converted_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
             checkpoint.pop(old_prefix + "mlp_x.w1.weight")
         )
-        new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
+        converted_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
         if i < num_layers - 1:
-            new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
+            converted_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
                 checkpoint.pop(old_prefix + "mlp_y.w1.weight")
             )
-            new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
+            converted_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(
+                old_prefix + "mlp_y.w2.weight"
+            )

     # Output layers
-    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
-    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
-    new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
-    new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
+    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")

-    new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
+    converted_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")

-    return new_state_dict
+    return converted_state_dict


 def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
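The mochi hunks above are a pure rename (new_state_dict becomes converted_state_dict, matching the other converters) plus reflowed line breaks; the conversion logic is unchanged. The fused-attention pattern they rely on is worth isolating: upstream checkpoints store q, k and v row-stacked in one matrix, and chunk(3, dim=0) recovers the separate diffusers projections. A standalone sketch with a stand-in hidden size (only the 3x row stacking matters here):

    import torch

    hidden_size = 128
    qkv_weight = torch.randn(3 * hidden_size, hidden_size)  # rows are [q; k; v]

    # An equal-size split along dim=0 undoes the row stacking.
    q, k, v = qkv_weight.chunk(3, dim=0)
    assert q.shape == k.shape == v.shape == (hidden_size, hidden_size)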
@@ -2685,3 +2741,521 @@ def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
2685
2741
  handler_fn_inplace(key, checkpoint)
2686
2742
 
2687
2743
  return checkpoint
2744
+
2745
+
2746
+ def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
2747
+ converted_state_dict = {}
2748
+ state_dict_keys = list(checkpoint.keys())
2749
+
2750
+ # Handle register tokens and positional embeddings
2751
+ converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)
2752
+
2753
+ # Handle time step projection
2754
+ converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
2755
+ converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
2756
+ converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
2757
+ converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)
2758
+
2759
+ # Handle context embedder
2760
+ converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)
2761
+
2762
+ # Calculate the number of layers
2763
+ def calculate_layers(keys, key_prefix):
2764
+ layers = set()
2765
+ for k in keys:
2766
+ if key_prefix in k:
2767
+ layer_num = int(k.split(".")[1]) # get the layer number
2768
+ layers.add(layer_num)
2769
+ return len(layers)
2770
+
2771
+ mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
2772
+ single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
2773
+
2774
+ # MMDiT blocks
2775
+ for i in range(mmdit_layers):
2776
+ # Feed-forward
2777
+ path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
2778
+ weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
2779
+ for orig_k, diffuser_k in path_mapping.items():
2780
+ for k, v in weight_mapping.items():
2781
+ converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
2782
+ f"double_layers.{i}.{orig_k}.{k}.weight", None
2783
+ )
2784
+
2785
+ # Norms
2786
+ path_mapping = {"modX": "norm1", "modC": "norm1_context"}
2787
+ for orig_k, diffuser_k in path_mapping.items():
2788
+ converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
2789
+ f"double_layers.{i}.{orig_k}.1.weight", None
2790
+ )
2791
+
2792
+ # Attentions
2793
+ x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
2794
+ context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
2795
+ for attn_mapping in [x_attn_mapping, context_attn_mapping]:
2796
+ for k, v in attn_mapping.items():
2797
+ converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
2798
+ f"double_layers.{i}.attn.{k}.weight", None
2799
+ )
2800
+
2801
+ # Single-DiT blocks
2802
+ for i in range(single_dit_layers):
2803
+ # Feed-forward
2804
+ mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
2805
+ for k, v in mapping.items():
2806
+ converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
2807
+ f"single_layers.{i}.mlp.{k}.weight", None
2808
+ )
2809
+
2810
+ # Norms
2811
+ converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
2812
+ f"single_layers.{i}.modCX.1.weight", None
2813
+ )
2814
+
2815
+ # Attentions
2816
+ x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
2817
+ for k, v in x_attn_mapping.items():
2818
+ converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
2819
+ f"single_layers.{i}.attn.{k}.weight", None
2820
+ )
2821
+ # Final blocks
2822
+ converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)
2823
+
2824
+ # Handle the final norm layer
2825
+ norm_weight = checkpoint.pop("modF.1.weight", None)
2826
+ if norm_weight is not None:
2827
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
2828
+ else:
2829
+ converted_state_dict["norm_out.linear.weight"] = None
2830
+
2831
+ converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
2832
+ converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
2833
+ converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
2834
+
2835
+ return converted_state_dict
2836
+
2837
+
2838
+ def convert_lumina2_to_diffusers(checkpoint, **kwargs):
2839
+ converted_state_dict = {}
2840
+
2841
+ # Original Lumina-Image-2 has an extra norm paramter that is unused
2842
+ # We just remove it here
2843
+ checkpoint.pop("norm_final.weight", None)
2844
+
2845
+ # Comfy checkpoints add this prefix
2846
+ keys = list(checkpoint.keys())
2847
+ for k in keys:
2848
+ if "model.diffusion_model." in k:
2849
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
2850
+
2851
+ LUMINA_KEY_MAP = {
2852
+ "cap_embedder": "time_caption_embed.caption_embedder",
2853
+ "t_embedder.mlp.0": "time_caption_embed.timestep_embedder.linear_1",
2854
+ "t_embedder.mlp.2": "time_caption_embed.timestep_embedder.linear_2",
2855
+ "attention": "attn",
2856
+ ".out.": ".to_out.0.",
2857
+ "k_norm": "norm_k",
2858
+ "q_norm": "norm_q",
2859
+ "w1": "linear_1",
2860
+ "w2": "linear_2",
2861
+ "w3": "linear_3",
2862
+ "adaLN_modulation.1": "norm1.linear",
2863
+ }
2864
+ ATTENTION_NORM_MAP = {
2865
+ "attention_norm1": "norm1.norm",
2866
+ "attention_norm2": "norm2",
2867
+ }
2868
+ CONTEXT_REFINER_MAP = {
2869
+ "context_refiner.0.attention_norm1": "context_refiner.0.norm1",
2870
+ "context_refiner.0.attention_norm2": "context_refiner.0.norm2",
2871
+ "context_refiner.1.attention_norm1": "context_refiner.1.norm1",
2872
+ "context_refiner.1.attention_norm2": "context_refiner.1.norm2",
2873
+ }
2874
+ FINAL_LAYER_MAP = {
2875
+ "final_layer.adaLN_modulation.1": "norm_out.linear_1",
2876
+ "final_layer.linear": "norm_out.linear_2",
2877
+ }
2878
+
2879
+ def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
2880
+ q_dim = 2304
2881
+ k_dim = v_dim = 768
2882
+
2883
+ to_q, to_k, to_v = torch.split(tensor, [q_dim, k_dim, v_dim], dim=0)
2884
+
2885
+ return {
2886
+ diffusers_key.replace("qkv", "to_q"): to_q,
2887
+ diffusers_key.replace("qkv", "to_k"): to_k,
2888
+ diffusers_key.replace("qkv", "to_v"): to_v,
2889
+ }
2890
+
2891
+ for key in keys:
2892
+ diffusers_key = key
2893
+ for k, v in CONTEXT_REFINER_MAP.items():
2894
+ diffusers_key = diffusers_key.replace(k, v)
2895
+ for k, v in FINAL_LAYER_MAP.items():
2896
+ diffusers_key = diffusers_key.replace(k, v)
2897
+ for k, v in ATTENTION_NORM_MAP.items():
2898
+ diffusers_key = diffusers_key.replace(k, v)
2899
+ for k, v in LUMINA_KEY_MAP.items():
2900
+ diffusers_key = diffusers_key.replace(k, v)
2901
+
2902
+ if "qkv" in diffusers_key:
2903
+ converted_state_dict.update(convert_lumina_attn_to_diffusers(checkpoint.pop(key), diffusers_key))
2904
+ else:
2905
+ converted_state_dict[diffusers_key] = checkpoint.pop(key)
2906
+
2907
+ return converted_state_dict
+
+
+ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
+     converted_state_dict = {}
+     keys = list(checkpoint.keys())
+     for k in keys:
+         if "model.diffusion_model." in k:
+             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1  # noqa: C401
+
+     # Positional and patch embeddings.
+     checkpoint.pop("pos_embed")
+     converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+     converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+     # Timestep embeddings.
+     converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+         "t_embedder.mlp.0.weight"
+     )
+     converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+     converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+         "t_embedder.mlp.2.weight"
+     )
+     converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+     converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
+     converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
+
+     # Caption Projection.
+     checkpoint.pop("y_embedder.y_embedding")
+     converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
+     converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
+     converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
+     converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
+     converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
+
+     for i in range(num_layers):
+         converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(
+             f"blocks.{i}.scale_shift_table"
+         )
+
+         # Self-Attention
+         sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
+         converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q])
+         converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k])
+         converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])
+
+         # Output Projections
+         converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(
+             f"blocks.{i}.attn.proj.weight"
+         )
+         converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(
+             f"blocks.{i}.attn.proj.bias"
+         )
+
+         # Cross-Attention
+         converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(
+             f"blocks.{i}.cross_attn.q_linear.weight"
+         )
+         converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(
+             f"blocks.{i}.cross_attn.q_linear.bias"
+         )
+
+         linear_sample_k, linear_sample_v = torch.chunk(
+             checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0
+         )
+         linear_sample_k_bias, linear_sample_v_bias = torch.chunk(
+             checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0
+         )
+         converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
+         converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
+         converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
+         converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias
+
+         # Output Projections
+         converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+             f"blocks.{i}.cross_attn.proj.weight"
+         )
+         converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+             f"blocks.{i}.cross_attn.proj.bias"
+         )
+
+         # MLP
+         converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(
+             f"blocks.{i}.mlp.inverted_conv.conv.weight"
+         )
+         converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(
+             f"blocks.{i}.mlp.inverted_conv.conv.bias"
+         )
+         converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(
+             f"blocks.{i}.mlp.depth_conv.conv.weight"
+         )
+         converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(
+             f"blocks.{i}.mlp.depth_conv.conv.bias"
+         )
+         converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(
+             f"blocks.{i}.mlp.point_conv.conv.weight"
+         )
+
+     # Final layer
+     converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+     converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+     converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
+
+     return converted_state_dict
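Aside (illustrative, not part of the diff): both attention conversions above split fused projections with torch.chunk along dim 0, self-attention qkv into three equal parts and cross-attention kv into two. A self-contained sketch, with hidden_size = 1152 chosen only for the example; the real width comes from the checkpoint:

import torch

hidden_size = 1152  # assumed width for illustration only
qkv = torch.randn(3 * hidden_size, hidden_size)
q, k, v = torch.chunk(qkv, 3, dim=0)  # each (hidden_size, hidden_size)
kv = torch.randn(2 * hidden_size, hidden_size)
k_cross, v_cross = torch.chunk(kv, 2, dim=0)  # each (hidden_size, hidden_size)
assert q.shape == k_cross.shape == (hidden_size, hidden_size)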
+
+
+ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
+     converted_state_dict = {}
+
+     keys = list(checkpoint.keys())
+     for k in keys:
+         if "model.diffusion_model." in k:
+             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+     TRANSFORMER_KEYS_RENAME_DICT = {
+         "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+         "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+         "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+         "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+         "time_projection.1": "condition_embedder.time_proj",
+         "cross_attn": "attn2",
+         "self_attn": "attn1",
+         ".o.": ".to_out.0.",
+         ".q.": ".to_q.",
+         ".k.": ".to_k.",
+         ".v.": ".to_v.",
+         ".k_img.": ".add_k_proj.",
+         ".v_img.": ".add_v_proj.",
+         ".norm_k_img.": ".norm_added_k.",
+         "head.modulation": "scale_shift_table",
+         "head.head": "proj_out",
+         "modulation": "scale_shift_table",
+         "ffn.0": "ffn.net.0.proj",
+         "ffn.2": "ffn.net.2",
+         # Hack to swap the layer names
+         # The original model calls the norms in the following order: norm1, norm3, norm2
+         # We convert it to: norm1, norm2, norm3
+         "norm2": "norm__placeholder",
+         "norm3": "norm2",
+         "norm__placeholder": "norm3",
+         # For the I2V model
+         "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
+         "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
+         "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
+         "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+     }
+
+     for key in list(checkpoint.keys()):
+         new_key = key[:]
+         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+             new_key = new_key.replace(replace_key, rename_key)
+
+         converted_state_dict[new_key] = checkpoint.pop(key)
+
+     return converted_state_dict
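Aside (illustrative, not part of the diff): the norm2/norm3 swap in TRANSFORMER_KEYS_RENAME_DICT works only because dicts preserve insertion order (Python 3.7+), so norm2 is parked under a placeholder before norm3 is renamed and neither replacement clobbers the other. A minimal standalone demonstration of the trick:

renames = {"norm2": "norm__placeholder", "norm3": "norm2", "norm__placeholder": "norm3"}

def rename(key: str) -> str:
    # Apply the replacements in insertion order, exactly as the converter's loop does.
    for old, new in renames.items():
        key = key.replace(old, new)
    return key

assert rename("blocks.0.norm2.weight") == "blocks.0.norm3.weight"
assert rename("blocks.0.norm3.weight") == "blocks.0.norm2.weight"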
+
+
+ def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
+     converted_state_dict = {}
+
+     # Create mappings for specific components
+     middle_key_mapping = {
+         # Encoder middle block
+         "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
+         "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
+         "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
+         "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
+         "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
+         "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
+         "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
+         "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
+         "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
+         "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
+         "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
+         "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
+         # Decoder middle block
+         "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
+         "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
+         "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
+         "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
+         "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
+         "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
+         "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
+         "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
+         "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
+         "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
+         "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
+         "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
+     }
+
+     # Create a mapping for attention blocks
+     attention_mapping = {
+         # Encoder middle attention
+         "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
+         "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
+         "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
+         "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
+         "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
+         # Decoder middle attention
+         "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
+         "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
+         "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
+         "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
+         "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
+     }
+
+     # Create a mapping for the head components
+     head_mapping = {
+         # Encoder head
+         "encoder.head.0.gamma": "encoder.norm_out.gamma",
+         "encoder.head.2.bias": "encoder.conv_out.bias",
+         "encoder.head.2.weight": "encoder.conv_out.weight",
+         # Decoder head
+         "decoder.head.0.gamma": "decoder.norm_out.gamma",
+         "decoder.head.2.bias": "decoder.conv_out.bias",
+         "decoder.head.2.weight": "decoder.conv_out.weight",
+     }
+
+     # Create a mapping for the quant components
+     quant_mapping = {
+         "conv1.weight": "quant_conv.weight",
+         "conv1.bias": "quant_conv.bias",
+         "conv2.weight": "post_quant_conv.weight",
+         "conv2.bias": "post_quant_conv.bias",
+     }
+
+     # Process each key in the state dict
+     for key, value in checkpoint.items():
+         # Handle middle block keys using the mapping
+         if key in middle_key_mapping:
+             new_key = middle_key_mapping[key]
+             converted_state_dict[new_key] = value
+         # Handle attention blocks using the mapping
+         elif key in attention_mapping:
+             new_key = attention_mapping[key]
+             converted_state_dict[new_key] = value
+         # Handle head keys using the mapping
+         elif key in head_mapping:
+             new_key = head_mapping[key]
+             converted_state_dict[new_key] = value
+         # Handle quant keys using the mapping
+         elif key in quant_mapping:
+             new_key = quant_mapping[key]
+             converted_state_dict[new_key] = value
+         # Handle encoder conv1
+         elif key == "encoder.conv1.weight":
+             converted_state_dict["encoder.conv_in.weight"] = value
+         elif key == "encoder.conv1.bias":
+             converted_state_dict["encoder.conv_in.bias"] = value
+         # Handle decoder conv1
+         elif key == "decoder.conv1.weight":
+             converted_state_dict["decoder.conv_in.weight"] = value
+         elif key == "decoder.conv1.bias":
+             converted_state_dict["decoder.conv_in.bias"] = value
+         # Handle encoder downsamples
+         elif key.startswith("encoder.downsamples."):
+             # Convert to down_blocks
+             new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
+
+             # Convert residual block naming but keep the original structure
+             if ".residual.0.gamma" in new_key:
+                 new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
+             elif ".residual.2.bias" in new_key:
+                 new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
+             elif ".residual.2.weight" in new_key:
+                 new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
+             elif ".residual.3.gamma" in new_key:
+                 new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
+             elif ".residual.6.bias" in new_key:
+                 new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
+             elif ".residual.6.weight" in new_key:
+                 new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
+             elif ".shortcut.bias" in new_key:
+                 new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
+             elif ".shortcut.weight" in new_key:
+                 new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
+
+             converted_state_dict[new_key] = value
+
+         # Handle decoder upsamples
+         elif key.startswith("decoder.upsamples."):
+             # Convert to up_blocks
+             parts = key.split(".")
+             block_idx = int(parts[2])
+
+             # Group residual blocks
+             if "residual" in key:
+                 if block_idx in [0, 1, 2]:
+                     new_block_idx = 0
+                     resnet_idx = block_idx
+                 elif block_idx in [4, 5, 6]:
+                     new_block_idx = 1
+                     resnet_idx = block_idx - 4
+                 elif block_idx in [8, 9, 10]:
+                     new_block_idx = 2
+                     resnet_idx = block_idx - 8
+                 elif block_idx in [12, 13, 14]:
+                     new_block_idx = 3
+                     resnet_idx = block_idx - 12
+                 else:
+                     # Keep as is for other blocks
+                     converted_state_dict[key] = value
+                     continue
+
+                 # Convert residual block naming
+                 if ".residual.0.gamma" in key:
+                     new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
+                 elif ".residual.2.bias" in key:
+                     new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
+                 elif ".residual.2.weight" in key:
+                     new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
+                 elif ".residual.3.gamma" in key:
+                     new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
+                 elif ".residual.6.bias" in key:
+                     new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
+                 elif ".residual.6.weight" in key:
+                     new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
+                 else:
+                     new_key = key
+
+                 converted_state_dict[new_key] = value
+
+             # Handle shortcut connections
+             elif ".shortcut." in key:
+                 if block_idx == 4:
+                     new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
+                     new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
+                 else:
+                     new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+                     new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
+
+                 converted_state_dict[new_key] = value
+
+             # Handle upsamplers
+             elif ".resample." in key or ".time_conv." in key:
+                 if block_idx == 3:
+                     new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
+                 elif block_idx == 7:
+                     new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
+                 elif block_idx == 11:
+                     new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
+                 else:
+                     new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+
+                 converted_state_dict[new_key] = value
+             else:
+                 new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+                 converted_state_dict[new_key] = value
+         else:
+             # Keep other keys unchanged
+             converted_state_dict[key] = value
+
+     return converted_state_dict
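Aside (illustrative, not part of the diff): for the residual and resample keys handled above, the flat decoder.upsamples.{i} indices map onto diffusers up_blocks in groups of four, with indices 3, 7 and 11 becoming the upsamplers of blocks 0-2. A hypothetical helper mirroring that arithmetic for the handled indices (0-14); keys outside that range are kept unchanged by the real converter:

def wan_decoder_prefix(block_idx: int) -> str:
    # Groups of four: 0-2 -> up_blocks.0.resnets.0-2, 3 -> up_blocks.0.upsamplers.0, and so on.
    if block_idx in (3, 7, 11):
        return f"decoder.up_blocks.{block_idx // 4}.upsamplers.0"
    return f"decoder.up_blocks.{block_idx // 4}.resnets.{block_idx % 4}"

assert wan_decoder_prefix(5) == "decoder.up_blocks.1.resnets.1"
assert wan_decoder_prefix(11) == "decoder.up_blocks.2.upsamplers.0"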