diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/single_file_utils.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Conversion script for the Stable Diffusion checkpoints."""
 
+import copy
 import os
 import re
 from contextlib import nullcontext
@@ -61,7 +62,14 @@ CHECKPOINT_KEY_NAMES = {
     "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
     "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
     "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias",
-    "controlnet": "control_model.time_embed.0.weight",
+    "controlnet": [
+        "control_model.time_embed.0.weight",
+        "controlnet_cond_embedding.conv_in.weight",
+    ],
+    # TODO: find non-Diffusers keys for controlnet_xl
+    "controlnet_xl": "add_embedding.linear_1.weight",
+    "controlnet_xl_large": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+    "controlnet_xl_mid": "down_blocks.1.attentions.0.norm.weight",
     "playground-v2-5": "edm_mean",
     "inpainting": "model.diffusion_model.input_blocks.0.0.weight",
     "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
@@ -73,7 +81,14 @@ CHECKPOINT_KEY_NAMES = {
     "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
     "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
     "stable_cascade_stage_c": "clip_txt_mapper.weight",
-    "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+    "sd3": [
+        "joint_blocks.0.context_block.adaLN_modulation.1.bias",
+        "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+    ],
+    "sd35_large": [
+        "joint_blocks.37.x_block.mlp.fc1.weight",
+        "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
+    ],
     "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
     "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
@@ -83,6 +98,17 @@ CHECKPOINT_KEY_NAMES = {
         "double_blocks.0.img_attn.norm.key_norm.scale",
         "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
     ],
+    "ltx-video": [
+        "model.diffusion_model.patchify_proj.weight",
+        "model.diffusion_model.transformer_blocks.27.scale_shift_table",
+        "patchify_proj.weight",
+        "transformer_blocks.27.scale_shift_table",
+        "vae.per_channel_statistics.mean-of-means",
+    ],
+    "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
+    "autoencoder-dc-sana": "encoder.project_in.conv.bias",
+    "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
+    "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -91,11 +117,14 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"},
     "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"},
     "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"},
-    "inpainting": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8-inpainting"},
+    "inpainting": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-inpainting"},
     "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
     "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
+    "controlnet_xl_large": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0"},
+    "controlnet_xl_mid": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-mid"},
+    "controlnet_xl_small": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-small"},
     "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
-    "v1": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8"},
+    "v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"},
     "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
     "stable_cascade_stage_b_lite": {
         "pretrained_model_name_or_path": "stabilityai/stable-cascade",
@@ -112,6 +141,12 @@
     "sd3": {
         "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
     },
+    "sd35_large": {
+        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large",
+    },
+    "sd35_medium": {
+        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-medium",
+    },
     "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
     "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
     "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
@@ -119,7 +154,17 @@
     "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
     "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
     "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
+    "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
+    "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
     "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
+    "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
+    "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
+    "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
+    "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
+    "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
+    "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
+    "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
+    "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
 }
 
 # Use to configure model sample size when original config is provided
@@ -456,6 +501,8 @@ def infer_diffusers_model_type(checkpoint):
     ):
         if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
            model_type = "inpainting_v2"
+        elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
+            model_type = "xl_inpaint"
         else:
             model_type = "inpainting"
 
@@ -474,8 +521,16 @@
     elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
         model_type = "upscale"
 
-    elif CHECKPOINT_KEY_NAMES["controlnet"] in checkpoint:
-        model_type = "controlnet"
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["controlnet"]):
+        if CHECKPOINT_KEY_NAMES["controlnet_xl"] in checkpoint:
+            if CHECKPOINT_KEY_NAMES["controlnet_xl_large"] in checkpoint:
+                model_type = "controlnet_xl_large"
+            elif CHECKPOINT_KEY_NAMES["controlnet_xl_mid"] in checkpoint:
+                model_type = "controlnet_xl_mid"
+            else:
+                model_type = "controlnet_xl_small"
+        else:
+            model_type = "controlnet"
 
     elif (
         CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
@@ -501,8 +556,21 @@
     ):
         model_type = "stable_cascade_stage_b"
 
-    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
-        model_type = "sd3"
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any(
+        checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"]
+    ):
+        if "model.diffusion_model.pos_embed" in checkpoint:
+            key = "model.diffusion_model.pos_embed"
+        else:
+            key = "pos_embed"
+
+        if checkpoint[key].shape[1] == 36864:
+            model_type = "sd3"
+        elif checkpoint[key].shape[1] == 147456:
+            model_type = "sd35_medium"
+
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
+        model_type = "sd35_large"
 
     elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
         if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
@@ -527,9 +595,44 @@
         if any(
             g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
         ):
-            model_type = "flux-dev"
+            if checkpoint["img_in.weight"].shape[1] == 384:
+                model_type = "flux-fill"
+
+            elif checkpoint["img_in.weight"].shape[1] == 128:
+                model_type = "flux-depth"
+            else:
+                model_type = "flux-dev"
         else:
             model_type = "flux-schnell"
+
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
+        if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
+            model_type = "ltx-video-0.9.1"
+        else:
+            model_type = "ltx-video"
+
+    elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
+        encoder_key = "encoder.project_in.conv.conv.bias"
+        decoder_key = "decoder.project_in.main.conv.weight"
+
+        if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint:
+            model_type = "autoencoder-dc-f32c32-sana"
+
+        elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32:
+            model_type = "autoencoder-dc-f32c32"
+
+        elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128:
+            model_type = "autoencoder-dc-f64c128"
+
+        else:
+            model_type = "autoencoder-dc-f128c512"
+
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
+        model_type = "mochi-1-preview"
+
+    elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
+        model_type = "hunyuan-video"
+
     else:
         model_type = "v1"
 
@@ -539,6 +642,7 @@ def infer_diffusers_model_type(checkpoint):
 def fetch_diffusers_config(checkpoint):
     model_type = infer_diffusers_model_type(checkpoint)
     model_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]
+    model_path = copy.deepcopy(model_path)
 
     return model_path
 
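The new deepcopy keeps callers from mutating the module-level path table: fetch_diffusers_config used to hand back the dict stored in DIFFUSERS_DEFAULT_PIPELINE_PATHS itself, so adding e.g. a subfolder key would corrupt every later lookup. A minimal sketch of the failure mode (simplified names, not library code):

import copy

PATHS = {"v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"}}

def fetch_config_shared(model_type):
    return PATHS[model_type]  # pre-0.32 behaviour: returns the shared dict

def fetch_config_copied(model_type):
    return copy.deepcopy(PATHS[model_type])  # 0.32 behaviour: private copy

cfg = fetch_config_shared("v1")
cfg["subfolder"] = "unet"                 # silently edits the global table
assert "subfolder" in PATHS["v1"]

PATHS["v1"].pop("subfolder")
cfg = fetch_config_copied("v1")
cfg["subfolder"] = "unet"                 # the global table stays clean
assert "subfolder" not in PATHS["v1"]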
@@ -1061,6 +1165,9 @@ def convert_controlnet_checkpoint(
     config,
     **kwargs,
 ):
+    # Return checkpoint if it's already been converted
+    if "time_embedding.linear_1.weight" in checkpoint:
+        return checkpoint
     # Some controlnet ckpt files are distributed independently from the rest of the
     # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
     if "time_embed.0.weight" in checkpoint:
@@ -1666,6 +1773,28 @@ def swap_scale_shift(weight, dim):
     return new_weight
 
 
+def swap_proj_gate(weight):
+    proj, gate = weight.chunk(2, dim=0)
+    new_weight = torch.cat([gate, proj], dim=0)
+    return new_weight
+
+
+def get_attn2_layers(state_dict):
+    attn2_layers = []
+    for key in state_dict.keys():
+        if "attn2." in key:
+            # Extract the layer number from the key
+            layer_num = int(key.split(".")[1])
+            attn2_layers.append(layer_num)
+
+    return tuple(sorted(set(attn2_layers)))
+
+
+def get_caption_projection_dim(state_dict):
+    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+    return caption_projection_dim
+
+
 def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
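To make the new helpers concrete, a small sketch with toy shapes (everything below is illustrative, not from the diff): swap_proj_gate reorders a fused [proj; gate] feed-forward weight into [gate; proj], and get_attn2_layers recovers which joint_blocks carry a second attention, i.e. SD3.5's dual-attention blocks.

import torch

w = torch.cat([torch.zeros(4, 8), torch.ones(4, 8)], dim=0)  # fused [proj; gate]
proj, gate = w.chunk(2, dim=0)
swapped = torch.cat([gate, proj], dim=0)                     # what swap_proj_gate returns
assert swapped[:4].eq(1).all() and swapped[4:].eq(0).all()

# get_attn2_layers splits "joint_blocks.<i>..." keys and keeps the block index.
state_dict = {
    "joint_blocks.0.x_block.attn2.qkv.weight": None,
    "joint_blocks.12.x_block.attn2.qkv.weight": None,
    "joint_blocks.3.x_block.attn.qkv.weight": None,  # attn, not attn2: ignored
}
attn2_layers = tuple(sorted({int(k.split(".")[1]) for k in state_dict if "attn2." in k}))
assert attn2_layers == (0, 12)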
@@ -1674,7 +1803,10 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1  # noqa: C401
-    caption_projection_dim = 1536
+    dual_attention_layers = get_attn2_layers(checkpoint)
+
+    caption_projection_dim = get_caption_projection_dim(checkpoint)
+    has_qk_norm = any("ln_q" in key for key in checkpoint.keys())
 
     # Positional and patch embeddings.
     converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
@@ -1731,6 +1863,21 @@
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
 
+        # qk norm
+        if has_qk_norm:
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+            )
+
         # output projections.
         converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
             f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -1746,6 +1893,38 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
             f"joint_blocks.{i}.context_block.attn.proj.bias"
         )
 
+        if i in dual_attention_layers:
+            # Q, K, V
+            sample_q2, sample_k2, sample_v2 = torch.chunk(
+                checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+            )
+            sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+                checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+            # qk norm
+            if has_qk_norm:
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+                )
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+                )
+
+            # output projections.
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.bias"
+            )
+
         # norms.
         converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
             f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
@@ -2094,3 +2273,411 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     )
 
     return converted_state_dict
+
+
+def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key}
+
+    TRANSFORMER_KEYS_RENAME_DICT = {
+        "model.diffusion_model.": "",
+        "patchify_proj": "proj_in",
+        "adaln_single": "time_embed",
+        "q_norm": "norm_q",
+        "k_norm": "norm_k",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+
+    for key in list(converted_state_dict.keys()):
+        new_key = key
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
+
+
+def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." in key}
+
+    def remove_keys_(key: str, state_dict):
+        state_dict.pop(key)
+
+    VAE_KEYS_RENAME_DICT = {
+        # common
+        "vae.": "",
+        # decoder
+        "up_blocks.0": "mid_block",
+        "up_blocks.1": "up_blocks.0",
+        "up_blocks.2": "up_blocks.1.upsamplers.0",
+        "up_blocks.3": "up_blocks.1",
+        "up_blocks.4": "up_blocks.2.conv_in",
+        "up_blocks.5": "up_blocks.2.upsamplers.0",
+        "up_blocks.6": "up_blocks.2",
+        "up_blocks.7": "up_blocks.3.conv_in",
+        "up_blocks.8": "up_blocks.3.upsamplers.0",
+        "up_blocks.9": "up_blocks.3",
+        # encoder
+        "down_blocks.0": "down_blocks.0",
+        "down_blocks.1": "down_blocks.0.downsamplers.0",
+        "down_blocks.2": "down_blocks.0.conv_out",
+        "down_blocks.3": "down_blocks.1",
+        "down_blocks.4": "down_blocks.1.downsamplers.0",
+        "down_blocks.5": "down_blocks.1.conv_out",
+        "down_blocks.6": "down_blocks.2",
+        "down_blocks.7": "down_blocks.2.downsamplers.0",
+        "down_blocks.8": "down_blocks.3",
+        "down_blocks.9": "mid_block",
+        # common
+        "conv_shortcut": "conv_shortcut.conv",
+        "res_blocks": "resnets",
+        "norm3.norm": "norm3",
+        "per_channel_statistics.mean-of-means": "latents_mean",
+        "per_channel_statistics.std-of-means": "latents_std",
+    }
+
+    VAE_091_RENAME_DICT = {
+        # decoder
+        "up_blocks.0": "mid_block",
+        "up_blocks.1": "up_blocks.0.upsamplers.0",
+        "up_blocks.2": "up_blocks.0",
+        "up_blocks.3": "up_blocks.1.upsamplers.0",
+        "up_blocks.4": "up_blocks.1",
+        "up_blocks.5": "up_blocks.2.upsamplers.0",
+        "up_blocks.6": "up_blocks.2",
+        "up_blocks.7": "up_blocks.3.upsamplers.0",
+        "up_blocks.8": "up_blocks.3",
+        # common
+        "last_time_embedder": "time_embedder",
+        "last_scale_shift_table": "scale_shift_table",
+    }
+
+    VAE_SPECIAL_KEYS_REMAP = {
+        "per_channel_statistics.channel": remove_keys_,
+        "per_channel_statistics.mean-of-means": remove_keys_,
+        "per_channel_statistics.mean-of-stds": remove_keys_,
+        "timestep_scale_multiplier": remove_keys_,
+    }
+
+    if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
+        VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
+
+    for key in list(converted_state_dict.keys()):
+        new_key = key
+        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
+
+
+def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+    def remap_qkv_(key: str, state_dict):
+        qkv = state_dict.pop(key)
+        q, k, v = torch.chunk(qkv, 3, dim=0)
+        parent_module, _, _ = key.rpartition(".qkv.conv.weight")
+        state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
+        state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
+        state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
+
+    def remap_proj_conv_(key: str, state_dict):
+        parent_module, _, _ = key.rpartition(".proj.conv.weight")
+        state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
+
+    AE_KEYS_RENAME_DICT = {
+        # common
+        "main.": "",
+        "op_list.": "",
+        "context_module": "attn",
+        "local_module": "conv_out",
+        # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
+        # If there were more scales, there would be more layers, so a loop would be better to handle this
+        "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
+        "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
+        "depth_conv.conv": "conv_depth",
+        "inverted_conv.conv": "conv_inverted",
+        "point_conv.conv": "conv_point",
+        "point_conv.norm": "norm",
+        "conv.conv.": "conv.",
+        "conv1.conv": "conv1",
+        "conv2.conv": "conv2",
+        "conv2.norm": "norm",
+        "proj.norm": "norm_out",
+        # encoder
+        "encoder.project_in.conv": "encoder.conv_in",
+        "encoder.project_out.0.conv": "encoder.conv_out",
+        "encoder.stages": "encoder.down_blocks",
+        # decoder
+        "decoder.project_in.conv": "decoder.conv_in",
+        "decoder.project_out.0": "decoder.norm_out",
+        "decoder.project_out.2.conv": "decoder.conv_out",
+        "decoder.stages": "decoder.up_blocks",
+    }
+
+    AE_F32C32_F64C128_F128C512_KEYS = {
+        "encoder.project_in.conv": "encoder.conv_in.conv",
+        "decoder.project_out.2.conv": "decoder.conv_out.conv",
+    }
+
+    AE_SPECIAL_KEYS_REMAP = {
+        "qkv.conv.weight": remap_qkv_,
+        "proj.conv.weight": remap_proj_conv_,
+    }
+    if "encoder.project_in.conv.bias" not in converted_state_dict:
+        AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS)
+
+    for key in list(converted_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
+
+
+def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    new_state_dict = {}
+
+    # Comfy checkpoints add this prefix
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    # Convert patch_embed
+    new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+    # Convert time_embed
+    new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+    new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
+    new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
+    new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
+    new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
+    new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
+    new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
+    new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
+    new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
+
+    # Convert transformer blocks
+    num_layers = 48
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        old_prefix = f"blocks.{i}."
+
+        # norm1
+        new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
+        new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
+            new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+        else:
+            new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
+                old_prefix + "mod_y.weight"
+            )
+            new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+
+        # Visual attention
+        qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
+        q, k, v = qkv_weight.chunk(3, dim=0)
+
+        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+        new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
+        new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
+        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
+        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
+
+        # Context attention
+        qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
+        q, k, v = qkv_weight.chunk(3, dim=0)
+
+        new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
+        new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
+        new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
+        new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
+            old_prefix + "attn.q_norm_y.weight"
+        )
+        new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
+            old_prefix + "attn.k_norm_y.weight"
+        )
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
+                old_prefix + "attn.proj_y.weight"
+            )
+            new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
+
+        # MLP
+        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
+            checkpoint.pop(old_prefix + "mlp_x.w1.weight")
+        )
+        new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
+                checkpoint.pop(old_prefix + "mlp_y.w1.weight")
+            )
+            new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
+
+    # Output layers
+    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
+    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
+    new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+
+    new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
+
+    return new_state_dict
+
+
+def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
+    def remap_norm_scale_shift_(key, state_dict):
+        weight = state_dict.pop(key)
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+    def remap_txt_in_(key, state_dict):
+        def rename_key(key):
+            new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+            new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+            new_key = new_key.replace("txt_in", "context_embedder")
+            new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+            new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+            new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+            new_key = new_key.replace("mlp", "ff")
+            return new_key
+
+        if "self_attn_qkv" in key:
+            weight = state_dict.pop(key)
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+        else:
+            state_dict[rename_key(key)] = state_dict.pop(key)
+
+    def remap_img_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        to_q, to_k, to_v = weight.chunk(3, dim=0)
+        state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+        state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+        state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+    def remap_txt_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        to_q, to_k, to_v = weight.chunk(3, dim=0)
+        state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+        state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+        state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
+    def remap_single_transformer_blocks_(key, state_dict):
+        hidden_size = 3072
+
+        if "linear1.weight" in key:
+            linear1_weight = state_dict.pop(key)
+            split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+            q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+            new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
+            state_dict[f"{new_key}.attn.to_q.weight"] = q
+            state_dict[f"{new_key}.attn.to_k.weight"] = k
+            state_dict[f"{new_key}.attn.to_v.weight"] = v
+            state_dict[f"{new_key}.proj_mlp.weight"] = mlp
+
+        elif "linear1.bias" in key:
+            linear1_bias = state_dict.pop(key)
+            split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+            q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+            new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
+            state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
+            state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
+            state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
+            state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
+
+        else:
+            new_key = key.replace("single_blocks", "single_transformer_blocks")
+            new_key = new_key.replace("linear2", "proj_out")
+            new_key = new_key.replace("q_norm", "attn.norm_q")
+            new_key = new_key.replace("k_norm", "attn.norm_k")
+            state_dict[new_key] = state_dict.pop(key)
+
+    TRANSFORMER_KEYS_RENAME_DICT = {
+        "img_in": "x_embedder",
+        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+        "double_blocks": "transformer_blocks",
+        "img_attn_q_norm": "attn.norm_q",
+        "img_attn_k_norm": "attn.norm_k",
+        "img_attn_proj": "attn.to_out.0",
+        "txt_attn_q_norm": "attn.norm_added_q",
+        "txt_attn_k_norm": "attn.norm_added_k",
+        "txt_attn_proj": "attn.to_add_out",
+        "img_mod.linear": "norm1.linear",
+        "img_norm1": "norm1.norm",
+        "img_norm2": "norm2",
+        "img_mlp": "ff",
+        "txt_mod.linear": "norm1_context.linear",
+        "txt_norm1": "norm1.norm",
+        "txt_norm2": "norm2_context",
+        "txt_mlp": "ff_context",
+        "self_attn_proj": "attn.to_out.0",
+        "modulation.linear": "norm.linear",
+        "pre_norm": "norm.norm",
+        "final_layer.norm_final": "norm_out.norm",
+        "final_layer.linear": "proj_out",
+        "fc1": "net.0.proj",
+        "fc2": "net.2",
+        "input_embedder": "proj_in",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {
+        "txt_in": remap_txt_in_,
+        "img_attn_qkv": remap_img_attn_qkv_,
+        "txt_attn_qkv": remap_txt_attn_qkv_,
+        "single_blocks": remap_single_transformer_blocks_,
+        "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+    }
+
+    def update_state_dict_(state_dict, old_key, new_key):
+        state_dict[new_key] = state_dict.pop(old_key)
+
+    for key in list(checkpoint.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        update_state_dict_(checkpoint, key, new_key)
+
+    for key in list(checkpoint.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, checkpoint)
+
+    return checkpoint
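All of this plumbing serves diffusers' single-file loading path: infer_diffusers_model_type inspects the raw state-dict keys to pick a config repo, and the matching convert_*_to_diffusers function remaps the weights. A hedged usage sketch (the checkpoint URL is a placeholder, and the dtype and model choice are illustrative):

import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/path/to/flux-checkpoint.safetensors",  # placeholder URL
    torch_dtype=torch.bfloat16,
)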