diffusers-0.27.0-py3-none-any.whl → diffusers-0.32.2-py3-none-any.whl

Files changed (445)
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +50 -53
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.0.dist-info/RECORD +0 -399
  443. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
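The diff excerpt that follows is from diffusers/loaders/single_file_utils.py (item 15 in the list above), the module behind the from_single_file() loading path that many of the new loaders and pipelines rely on. For orientation, here is a minimal usage sketch (assumed usage, not taken from this diff; the checkpoint URL is illustrative) of the public entry point these utilities serve:

    # Minimal sketch (assumption, not part of the diff below): load a single-file
    # checkpoint through the from_single_file() entry point backed by these utilities.
    # The checkpoint URL is illustrative only.
    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_single_file(
        "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
        torch_dtype=torch.float16,
    )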
@@ -12,8 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Conversion script for the Stable Diffusion checkpoints."""
+"""Conversion script for the Stable Diffusion checkpoints."""
 
+import copy
 import os
 import re
 from contextlib import nullcontext
@@ -21,12 +22,12 @@ from io import BytesIO
 from urllib.parse import urlparse
 
 import requests
+import torch
 import yaml
 
 from ..models.modeling_utils import load_state_dict
 from ..schedulers import (
     DDIMScheduler,
-    DDPMScheduler,
     DPMSolverMultistepScheduler,
     EDMDPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
@@ -35,133 +36,152 @@ from ..schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ..utils import is_accelerate_available, is_transformers_available, logging
+from ..utils import (
+    SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    deprecate,
+    is_accelerate_available,
+    is_transformers_available,
+    logging,
+)
 from ..utils.hub_utils import _get_model_file
 
 
 if is_transformers_available():
-    from transformers import (
-        CLIPTextConfig,
-        CLIPTextModel,
-        CLIPTextModelWithProjection,
-        CLIPTokenizer,
-    )
+    from transformers import AutoImageProcessor
 
 if is_accelerate_available():
     from accelerate import init_empty_weights
 
-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+    from ..models.modeling_utils import load_model_dict_into_meta
 
-CONFIG_URLS = {
-    "v1": "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
-    "v2": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml",
-    "xl": "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml",
-    "xl_refiner": "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml",
-    "upscale": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml",
-    "controlnet": "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml",
-}
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 CHECKPOINT_KEY_NAMES = {
     "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
     "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
     "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
+    "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias",
+    "controlnet": [
+        "control_model.time_embed.0.weight",
+        "controlnet_cond_embedding.conv_in.weight",
+    ],
+    # TODO: find non-Diffusers keys for controlnet_xl
+    "controlnet_xl": "add_embedding.linear_1.weight",
+    "controlnet_xl_large": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+    "controlnet_xl_mid": "down_blocks.1.attentions.0.norm.weight",
+    "playground-v2-5": "edm_mean",
+    "inpainting": "model.diffusion_model.input_blocks.0.0.weight",
+    "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
+    "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",
+    "clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight",
+    "open_clip": "cond_stage_model.model.token_embedding.weight",
+    "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",
+    "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
+    "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
+    "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
+    "stable_cascade_stage_c": "clip_txt_mapper.weight",
+    "sd3": [
+        "joint_blocks.0.context_block.adaLN_modulation.1.bias",
+        "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+    ],
+    "sd35_large": [
+        "joint_blocks.37.x_block.mlp.fc1.weight",
+        "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
+    ],
+    "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
+    "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
+    "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
+    "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
+    "animatediff_rgb": "controlnet_cond_embedding.weight",
+    "flux": [
+        "double_blocks.0.img_attn.norm.key_norm.scale",
+        "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
+    ],
+    "ltx-video": [
+        "model.diffusion_model.patchify_proj.weight",
+        "model.diffusion_model.transformer_blocks.27.scale_shift_table",
+        "patchify_proj.weight",
+        "transformer_blocks.27.scale_shift_table",
+        "vae.per_channel_statistics.mean-of-means",
+    ],
+    "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
+    "autoencoder-dc-sana": "encoder.project_in.conv.bias",
+    "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
+    "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
 }
 
-SCHEDULER_DEFAULT_CONFIG = {
-    "beta_schedule": "scaled_linear",
-    "beta_start": 0.00085,
-    "beta_end": 0.012,
-    "interpolation_type": "linear",
-    "num_train_timesteps": 1000,
-    "prediction_type": "epsilon",
-    "sample_max_value": 1.0,
-    "set_alpha_to_one": False,
-    "skip_prk_steps": True,
-    "steps_offset": 1,
-    "timestep_spacing": "leading",
+DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
+    "xl_base": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0"},
+    "xl_refiner": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-refiner-1.0"},
+    "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"},
+    "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"},
+    "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"},
+    "inpainting": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-inpainting"},
+    "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
+    "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
+    "controlnet_xl_large": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0"},
+    "controlnet_xl_mid": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-mid"},
+    "controlnet_xl_small": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-small"},
+    "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
+    "v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"},
+    "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
+    "stable_cascade_stage_b_lite": {
+        "pretrained_model_name_or_path": "stabilityai/stable-cascade",
+        "subfolder": "decoder_lite",
+    },
+    "stable_cascade_stage_c": {
+        "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
+        "subfolder": "prior",
+    },
+    "stable_cascade_stage_c_lite": {
+        "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
+        "subfolder": "prior_lite",
    },
+    "sd3": {
+        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
+    },
+    "sd35_large": {
+        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large",
+    },
+    "sd35_medium": {
+        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-medium",
+    },
+    "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
+    "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
+    "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
+    "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
+    "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
+    "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
+    "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
+    "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
+    "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
+    "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
+    "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
+    "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
+    "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
+    "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
+    "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
+    "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
+    "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
+    "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
 }
 
-
-STABLE_CASCADE_DEFAULT_CONFIGS = {
-    "stage_c": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior"},
-    "stage_c_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior_lite"},
-    "stage_b": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder"},
-    "stage_b_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder_lite"},
+# Use to configure model sample size when original config is provided
+DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = {
+    "xl_base": 1024,
+    "xl_refiner": 1024,
+    "xl_inpaint": 1024,
+    "playground-v2-5": 1024,
+    "upscale": 512,
+    "inpainting": 512,
+    "inpainting_v2": 512,
+    "controlnet": 512,
+    "v2": 768,
+    "v1": 512,
 }
 
 
-def convert_stable_cascade_unet_single_file_to_diffusers(original_state_dict):
-    is_stage_c = "clip_txt_mapper.weight" in original_state_dict
-
-    if is_stage_c:
-        state_dict = {}
-        for key in original_state_dict.keys():
-            if key.endswith("in_proj_weight"):
-                weights = original_state_dict[key].chunk(3, 0)
-                state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
-                state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
-                state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
-            elif key.endswith("in_proj_bias"):
-                weights = original_state_dict[key].chunk(3, 0)
-                state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
-                state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
-                state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
-            elif key.endswith("out_proj.weight"):
-                weights = original_state_dict[key]
-                state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
-            elif key.endswith("out_proj.bias"):
-                weights = original_state_dict[key]
-                state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
-            else:
-                state_dict[key] = original_state_dict[key]
-    else:
-        state_dict = {}
-        for key in original_state_dict.keys():
-            if key.endswith("in_proj_weight"):
-                weights = original_state_dict[key].chunk(3, 0)
-                state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
-                state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
-                state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
-            elif key.endswith("in_proj_bias"):
-                weights = original_state_dict[key].chunk(3, 0)
-                state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
-                state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
-                state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
-            elif key.endswith("out_proj.weight"):
-                weights = original_state_dict[key]
-                state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
-            elif key.endswith("out_proj.bias"):
-                weights = original_state_dict[key]
-                state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
-            # rename clip_mapper to clip_txt_pooled_mapper
-            elif key.endswith("clip_mapper.weight"):
-                weights = original_state_dict[key]
-                state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
-            elif key.endswith("clip_mapper.bias"):
-                weights = original_state_dict[key]
-                state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
-            else:
-                state_dict[key] = original_state_dict[key]
-
-    return state_dict
-
-
-def infer_stable_cascade_single_file_config(checkpoint):
-    is_stage_c = "clip_txt_mapper.weight" in checkpoint
-    is_stage_b = "down_blocks.1.0.channelwise.0.weight" in checkpoint
-
-    if is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 1536):
-        config_type = "stage_c_lite"
-    elif is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 2048):
-        config_type = "stage_c"
-    elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 576:
-        config_type = "stage_b_lite"
-    elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 640:
-        config_type = "stage_b"
-
-    return STABLE_CASCADE_DEFAULT_CONFIGS[config_type]
-
-
 DIFFUSERS_TO_LDM_MAPPING = {
     "unet": {
         "layers": {
@@ -255,14 +275,6 @@ DIFFUSERS_TO_LDM_MAPPING = {
     },
 }
 
-LDM_VAE_KEY = "first_stage_model."
-LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
-PLAYGROUND_VAE_SCALING_FACTOR = 0.5
-LDM_UNET_KEY = "model.diffusion_model."
-LDM_CONTROLNET_KEY = "control_model."
-LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
-LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
-
 
 SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
     "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias",
@@ -279,11 +291,54 @@ SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
     "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight",
     "cond_stage_model.model.text_projection",
 ]
+# To support legacy scheduler_type argument
+SCHEDULER_DEFAULT_CONFIG = {
+    "beta_schedule": "scaled_linear",
+    "beta_start": 0.00085,
+    "beta_end": 0.012,
+    "interpolation_type": "linear",
+    "num_train_timesteps": 1000,
+    "prediction_type": "epsilon",
+    "sample_max_value": 1.0,
+    "set_alpha_to_one": False,
+    "skip_prk_steps": True,
+    "steps_offset": 1,
+    "timestep_spacing": "leading",
+}
+
+LDM_VAE_KEYS = ["first_stage_model.", "vae."]
+LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
+PLAYGROUND_VAE_SCALING_FACTOR = 0.5
+LDM_UNET_KEY = "model.diffusion_model."
+LDM_CONTROLNET_KEY = "control_model."
+LDM_CLIP_PREFIX_TO_REMOVE = [
+    "cond_stage_model.transformer.",
+    "conditioner.embedders.0.transformer.",
+]
+LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
+SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]
 
 VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
 
 
+class SingleFileComponentError(Exception):
+    def __init__(self, message=None):
+        self.message = message
+        super().__init__(self.message)
+
+
+def is_valid_url(url):
+    result = urlparse(url)
+    if result.scheme and result.netloc:
+        return True
+
+    return False
+
+
 def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
+    if not is_valid_url(pretrained_model_name_or_path):
+        raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")
+
     pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)"
    weights_name = None
     repo_id = (None,)
@@ -291,6 +346,7 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
         pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
     match = re.match(pattern, pretrained_model_name_or_path)
     if not match:
+        logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
         return repo_id, weights_name
 
     repo_id = f"{match.group(1)}/{match.group(2)}"
@@ -299,36 +355,23 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
     return repo_id, weights_name
 
 
-def fetch_ldm_config_and_checkpoint(
-    pretrained_model_link_or_path,
-    class_name,
-    original_config_file=None,
-    resume_download=False,
-    force_download=False,
-    proxies=None,
-    token=None,
-    cache_dir=None,
-    local_files_only=None,
-    revision=None,
-):
-    checkpoint = load_single_file_model_checkpoint(
-        pretrained_model_link_or_path,
-        resume_download=resume_download,
-        force_download=force_download,
-        proxies=proxies,
-        token=token,
-        cache_dir=cache_dir,
-        local_files_only=local_files_only,
-        revision=revision,
-    )
-    original_config = fetch_original_config(class_name, checkpoint, original_config_file)
+def _is_model_weights_in_cached_folder(cached_folder, name):
+    pretrained_model_name_or_path = os.path.join(cached_folder, name)
+    weights_exist = False
+
+    for weights_name in [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME]:
+        if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
+            weights_exist = True
+
+    return weights_exist
 
-    return original_config, checkpoint
 
+def _is_legacy_scheduler_kwargs(kwargs):
+    return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys())
 
-def load_single_file_model_checkpoint(
+
+def load_single_file_checkpoint(
     pretrained_model_link_or_path,
-    resume_download=False,
     force_download=False,
     proxies=None,
     token=None,
@@ -337,21 +380,22 @@ def load_single_file_model_checkpoint(
     revision=None,
 ):
     if os.path.isfile(pretrained_model_link_or_path):
-        checkpoint = load_state_dict(pretrained_model_link_or_path)
+        pretrained_model_link_or_path = pretrained_model_link_or_path
+
     else:
         repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
-        checkpoint_path = _get_model_file(
+        pretrained_model_link_or_path = _get_model_file(
            repo_id,
             weights_name=weights_name,
             force_download=force_download,
             cache_dir=cache_dir,
-            resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
             token=token,
             revision=revision,
         )
-        checkpoint = load_state_dict(checkpoint_path)
+
+    checkpoint = load_state_dict(pretrained_model_link_or_path)
 
     # some checkpoints contain the model state dict under a "state_dict" key
     while "state_dict" in checkpoint:
@@ -360,120 +404,262 @@ def load_single_file_model_checkpoint(
     return checkpoint
 
 
-def infer_original_config_file(class_name, checkpoint):
-    if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
-        config_url = CONFIG_URLS["v2"]
+def fetch_original_config(original_config_file, local_files_only=False):
+    if os.path.isfile(original_config_file):
+        with open(original_config_file, "r") as fp:
+            original_config_file = fp.read()
 
-    elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
-        config_url = CONFIG_URLS["xl"]
+    elif is_valid_url(original_config_file):
+        if local_files_only:
+            raise ValueError(
+                "`local_files_only` is set to True, but a URL was provided as `original_config_file`. "
+                "Please provide a valid local file path."
+            )
 
-    elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
-        config_url = CONFIG_URLS["xl_refiner"]
+        original_config_file = BytesIO(requests.get(original_config_file).content)
 
-    elif class_name == "StableDiffusionUpscalePipeline":
-        config_url = CONFIG_URLS["upscale"]
+    else:
+        raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
 
-    elif class_name == "ControlNetModel":
-        config_url = CONFIG_URLS["controlnet"]
+    original_config = yaml.safe_load(original_config_file)
 
-    else:
-        config_url = CONFIG_URLS["v1"]
+    return original_config
 
-    original_config_file = BytesIO(requests.get(config_url).content)
 
-    return original_config_file
+def is_clip_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["clip"] in checkpoint:
+        return True
 
+    return False
 
-def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=None):
-    def is_valid_url(url):
-        result = urlparse(url)
-        if result.scheme and result.netloc:
-            return True
 
-        return False
+def is_clip_sdxl_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["clip_sdxl"] in checkpoint:
+        return True
 
-    if original_config_file is None:
-        original_config_file = infer_original_config_file(pipeline_class_name, checkpoint)
+    return False
 
-    elif os.path.isfile(original_config_file):
-        with open(original_config_file, "r") as fp:
-            original_config_file = fp.read()
 
-    elif is_valid_url(original_config_file):
-        original_config_file = BytesIO(requests.get(original_config_file).content)
+def is_clip_sd3_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint:
+        return True
 
-    else:
-        raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
+    return False
 
-    original_config = yaml.safe_load(original_config_file)
 
-    return original_config
+def is_open_clip_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint:
+        return True
 
+    return False
 
-def infer_model_type(original_config, checkpoint, model_type=None):
-    if model_type is not None:
-        return model_type
 
-    has_cond_stage_config = (
-        "cond_stage_config" in original_config["model"]["params"]
-        and original_config["model"]["params"]["cond_stage_config"] is not None
-    )
-    has_network_config = (
-        "network_config" in original_config["model"]["params"]
-        and original_config["model"]["params"]["network_config"] is not None
+def is_open_clip_sdxl_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["open_clip_sdxl"] in checkpoint:
+        return True
+
+    return False
+
+
+def is_open_clip_sd3_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
+        return True
+
+    return False
+
+
+def is_open_clip_sdxl_refiner_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
+        return True
+
+    return False
+
+
+def is_clip_model_in_single_file(class_obj, checkpoint):
+    is_clip_in_checkpoint = any(
+        [
+            is_clip_model(checkpoint),
+            is_clip_sd3_model(checkpoint),
+            is_open_clip_model(checkpoint),
+            is_open_clip_sdxl_model(checkpoint),
+            is_open_clip_sdxl_refiner_model(checkpoint),
+            is_open_clip_sd3_model(checkpoint),
+        ]
     )
+    if (
+        class_obj.__name__ == "CLIPTextModel" or class_obj.__name__ == "CLIPTextModelWithProjection"
+    ) and is_clip_in_checkpoint:
+        return True
 
-    if has_cond_stage_config:
-        model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1]
+    return False
 
-    elif has_network_config:
-        context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"]
-        if "edm_mean" in checkpoint and "edm_std" in checkpoint:
-            model_type = "Playground"
-        elif context_dim == 2048:
-            model_type = "SDXL"
+
+def infer_diffusers_model_type(checkpoint):
+    if (
+        CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9
+    ):
+        if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
+            model_type = "inpainting_v2"
+        elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
+            model_type = "xl_inpaint"
         else:
-            model_type = "SDXL-Refiner"
-    else:
-        raise ValueError("Unable to infer model type from config")
+            model_type = "inpainting"
 
-    logger.debug(f"No `model_type` given, `model_type` inferred as: {model_type}")
+    elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
+        model_type = "v2"
 
-    return model_type
+    elif CHECKPOINT_KEY_NAMES["playground-v2-5"] in checkpoint:
+        model_type = "playground-v2-5"
 
+    elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
+        model_type = "xl_base"
 
-def get_default_scheduler_config():
-    return SCHEDULER_DEFAULT_CONFIG
+    elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
+        model_type = "xl_refiner"
 
+    elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
+        model_type = "upscale"
 
-def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=None, model_type=None):
-    if image_size:
-        return image_size
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["controlnet"]):
+        if CHECKPOINT_KEY_NAMES["controlnet_xl"] in checkpoint:
+            if CHECKPOINT_KEY_NAMES["controlnet_xl_large"] in checkpoint:
+                model_type = "controlnet_xl_large"
+            elif CHECKPOINT_KEY_NAMES["controlnet_xl_mid"] in checkpoint:
+                model_type = "controlnet_xl_mid"
+            else:
+                model_type = "controlnet_xl_small"
+        else:
+            model_type = "controlnet"
 
-    global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
-    model_type = infer_model_type(original_config, checkpoint, model_type)
+    elif (
+        CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 1536
+    ):
+        model_type = "stable_cascade_stage_c_lite"
 
-    if pipeline_class_name == "StableDiffusionUpscalePipeline":
-        image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
-        return image_size
+    elif (
+        CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 2048
+    ):
+        model_type = "stable_cascade_stage_c"

-    elif model_type in ["SDXL", "SDXL-Refiner", "Playground"]:
-        image_size = 1024
-        return image_size
+    elif (
+        CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 576
+    ):
+        model_type = "stable_cascade_stage_b_lite"
 
     elif (
-        "parameterization" in original_config["model"]["params"]
-        and original_config["model"]["params"]["parameterization"] == "v"
+        CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 640
     ):
-        # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
-        # as it relies on a brittle global step parameter here
-        image_size = 512 if global_step == 875000 else 768
-        return image_size
+        model_type = "stable_cascade_stage_b"
+
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any(
+        checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"]
+    ):
+        if "model.diffusion_model.pos_embed" in checkpoint:
+            key = "model.diffusion_model.pos_embed"
+        else:
+            key = "pos_embed"
+
+        if checkpoint[key].shape[1] == 36864:
+            model_type = "sd3"
+        elif checkpoint[key].shape[1] == 147456:
+            model_type = "sd35_medium"
+
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
+        model_type = "sd35_large"
+
+    elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
+        if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
+            model_type = "animatediff_scribble"
578
+
579
+ elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
580
+ model_type = "animatediff_rgb"
581
+
582
+ elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
583
+ model_type = "animatediff_v2"
584
+
585
+ elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
586
+ model_type = "animatediff_sdxl_beta"
587
+
588
+ elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24:
589
+ model_type = "animatediff_v1"
590
+
591
+ else:
592
+ model_type = "animatediff_v3"
593
+
594
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
595
+ if any(
596
+ g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
597
+ ):
598
+ if "model.diffusion_model.img_in.weight" in checkpoint:
599
+ key = "model.diffusion_model.img_in.weight"
600
+ else:
601
+ key = "img_in.weight"
602
+
603
+ if checkpoint[key].shape[1] == 384:
604
+ model_type = "flux-fill"
605
+ elif checkpoint[key].shape[1] == 128:
606
+ model_type = "flux-depth"
607
+ else:
608
+ model_type = "flux-dev"
609
+ else:
610
+ model_type = "flux-schnell"
611
+
612
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
613
+ if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
614
+ model_type = "ltx-video-0.9.1"
615
+ else:
616
+ model_type = "ltx-video"
617
+
618
+ elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
619
+ encoder_key = "encoder.project_in.conv.conv.bias"
620
+ decoder_key = "decoder.project_in.main.conv.weight"
621
+
622
+ if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint:
623
+ model_type = "autoencoder-dc-f32c32-sana"
624
+
625
+ elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32:
626
+ model_type = "autoencoder-dc-f32c32"
627
+
628
+ elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128:
629
+ model_type = "autoencoder-dc-f64c128"
630
+
631
+ else:
632
+ model_type = "autoencoder-dc-f128c512"
633
+
634
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
635
+ model_type = "mochi-1-preview"
636
+
637
+ elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
638
+ model_type = "hunyuan-video"
472
639
 
473
640
  else:
474
- image_size = 512
641
+ model_type = "v1"
642
+
643
+ return model_type
644
+
645
+
646
+ def fetch_diffusers_config(checkpoint):
647
+ model_type = infer_diffusers_model_type(checkpoint)
648
+ model_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]
649
+ model_path = copy.deepcopy(model_path)
650
+
651
+ return model_path
652
+
653
+
654
+ def set_image_size(checkpoint, image_size=None):
655
+ if image_size:
475
656
  return image_size
476
657
 
658
+ model_type = infer_diffusers_model_type(checkpoint)
659
+ image_size = DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP[model_type]
660
+
661
+ return image_size
662
+
477
663
 
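A minimal sketch of how the key/shape heuristics above are meant to be driven (the checkpoint filename and the safetensors loading step are illustrative assumptions, not part of this file):

from safetensors.torch import load_file

# Load a single-file checkpoint as a flat state dict of tensors (path is hypothetical).
checkpoint = load_file("v1-5-pruned-emaonly.safetensors")

# Key- and shape-based detection yields a model type string such as "v1", "xl_base" or "flux-dev",
# and fetch_diffusers_config maps it to the default diffusers repo entry for that type.
model_type = infer_diffusers_model_type(checkpoint)
default_config = fetch_diffusers_config(checkpoint)
print(model_type, default_config)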
478
664
  # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
479
665
  def conv_attn_to_linear(checkpoint):
@@ -488,10 +674,21 @@ def conv_attn_to_linear(checkpoint):
488
674
  checkpoint[key] = checkpoint[key][:, :, 0]
489
675
 
490
676
 
491
- def create_unet_diffusers_config(original_config, image_size: int):
677
+ def create_unet_diffusers_config_from_ldm(
678
+ original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None
679
+ ):
492
680
  """
493
681
  Creates a config for the diffusers based on the config of the LDM model.
494
682
  """
683
+ if image_size is not None:
684
+ deprecation_message = (
685
+ "Configuring UNet2DConditionModel with the `image_size` argument to `from_single_file`"
686
+ "is deprecated and will be ignored in future versions."
687
+ )
688
+ deprecate("image_size", "1.0.0", deprecation_message)
689
+
690
+ image_size = set_image_size(checkpoint, image_size=image_size)
691
+
495
692
  if (
496
693
  "unet_config" in original_config["model"]["params"]
497
694
  and original_config["model"]["params"]["unet_config"] is not None
@@ -500,6 +697,16 @@ def create_unet_diffusers_config(original_config, image_size: int):
500
697
  else:
501
698
  unet_params = original_config["model"]["params"]["network_config"]["params"]
502
699
 
700
+ if num_in_channels is not None:
701
+ deprecation_message = (
702
+ "Configuring UNet2DConditionModel with the `num_in_channels` argument to `from_single_file`"
703
+ "is deprecated and will be ignored in future versions."
704
+ )
705
+ deprecate("image_size", "1.0.0", deprecation_message)
706
+ in_channels = num_in_channels
707
+ else:
708
+ in_channels = unet_params["in_channels"]
709
+
503
710
  vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
504
711
  block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
505
712
 
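For reference, the `block_out_channels` line above is a plain element-wise product of `model_channels` and `channel_mult`; with the standard SD 1.x values it reduces to the following self-contained sketch:

# Standard SD 1.x UNet widths, shown only to make the list comprehension concrete.
unet_params = {"model_channels": 320, "channel_mult": [1, 2, 4, 4]}
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
assert block_out_channels == [320, 640, 1280, 1280]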
@@ -564,7 +771,7 @@ def create_unet_diffusers_config(original_config, image_size: int):
564
771
 
565
772
  config = {
566
773
  "sample_size": image_size // vae_scale_factor,
567
- "in_channels": unet_params["in_channels"],
774
+ "in_channels": in_channels,
568
775
  "down_block_types": down_block_types,
569
776
  "block_out_channels": block_out_channels,
570
777
  "layers_per_block": unet_params["num_res_blocks"],
@@ -578,6 +785,14 @@ def create_unet_diffusers_config(original_config, image_size: int):
578
785
  "transformer_layers_per_block": transformer_layers_per_block,
579
786
  }
580
787
 
788
+ if upcast_attention is not None:
789
+ deprecation_message = (
790
+ "Configuring UNet2DConditionModel with the `upcast_attention` argument to `from_single_file`"
791
+ "is deprecated and will be ignored in future versions."
792
+ )
793
+ deprecate("image_size", "1.0.0", deprecation_message)
794
+ config["upcast_attention"] = upcast_attention
795
+
581
796
  if "disable_self_attentions" in unet_params:
582
797
  config["only_cross_attention"] = unet_params["disable_self_attentions"]
583
798
 
@@ -590,9 +805,18 @@ def create_unet_diffusers_config(original_config, image_size: int):
590
805
  return config
591
806
 
592
807
 
593
- def create_controlnet_diffusers_config(original_config, image_size: int):
808
+ def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs):
809
+ if image_size is not None:
810
+ deprecation_message = (
811
+ "Configuring ControlNetModel with the `image_size` argument"
812
+ "is deprecated and will be ignored in future versions."
813
+ )
814
+ deprecate("image_size", "1.0.0", deprecation_message)
815
+
816
+ image_size = set_image_size(checkpoint, image_size=image_size)
817
+
594
818
  unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
595
- diffusers_unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
819
+ diffusers_unet_config = create_unet_diffusers_config_from_ldm(original_config, checkpoint, image_size=image_size)
596
820
 
597
821
  controlnet_config = {
598
822
  "conditioning_channels": unet_params["hint_channels"],
@@ -613,15 +837,33 @@ def create_controlnet_diffusers_config(original_config, image_size: int):
613
837
  return controlnet_config
614
838
 
615
839
 
616
- def create_vae_diffusers_config(original_config, image_size, scaling_factor=None, latents_mean=None, latents_std=None):
840
+ def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None):
617
841
  """
618
842
  Creates a config for the diffusers based on the config of the LDM model.
619
843
  """
844
+ if image_size is not None:
845
+ deprecation_message = (
846
+ "Configuring AutoencoderKL with the `image_size` argument"
847
+ "is deprecated and will be ignored in future versions."
848
+ )
849
+ deprecate("image_size", "1.0.0", deprecation_message)
850
+
851
+ image_size = set_image_size(checkpoint, image_size=image_size)
852
+
853
+ if "edm_mean" in checkpoint and "edm_std" in checkpoint:
854
+ latents_mean = checkpoint["edm_mean"]
855
+ latents_std = checkpoint["edm_std"]
856
+ else:
857
+ latents_mean = None
858
+ latents_std = None
859
+
620
860
  vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
621
861
  if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
622
862
  scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR
863
+
623
864
  elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
624
865
  scaling_factor = original_config["model"]["params"]["scale_factor"]
866
+
625
867
  elif scaling_factor is None:
626
868
  scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
627
869
 
@@ -658,48 +900,136 @@ def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, ma
658
900
  )
659
901
  if mapping:
660
902
  diffusers_key = diffusers_key.replace(mapping["old"], mapping["new"])
661
- new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
903
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
662
904
 
663
905
 
664
906
  def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping):
665
907
  for ldm_key in ldm_keys:
666
908
  diffusers_key = ldm_key.replace(mapping["old"], mapping["new"])
667
- new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
909
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
668
910
 
669
911
 
670
- def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
671
- """
672
- Takes a state dict and a config, and returns a converted checkpoint.
673
- """
674
- # extract state_dict for UNet
675
- unet_state_dict = {}
676
- keys = list(checkpoint.keys())
677
- unet_key = LDM_UNET_KEY
912
+ def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
913
+ for ldm_key in keys:
914
+ diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
915
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
678
916
 
679
- # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
680
- if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
681
- logger.warning("Checkpoint has both EMA and non-EMA weights.")
682
- logger.warning(
683
- "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
684
- " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
685
- )
686
- for key in keys:
687
- if key.startswith("model.diffusion_model"):
688
- flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
689
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
690
- else:
691
- if sum(k.startswith("model_ema") for k in keys) > 100:
692
- logger.warning(
693
- "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
694
- " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
695
- )
696
- for key in keys:
697
- if key.startswith(unet_key):
698
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
699
917
 
700
- new_checkpoint = {}
701
- ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
702
- for diffusers_key, ldm_key in ldm_unet_keys.items():
918
+ def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
919
+ for ldm_key in keys:
920
+ diffusers_key = (
921
+ ldm_key.replace(mapping["old"], mapping["new"])
922
+ .replace("norm.weight", "group_norm.weight")
923
+ .replace("norm.bias", "group_norm.bias")
924
+ .replace("q.weight", "to_q.weight")
925
+ .replace("q.bias", "to_q.bias")
926
+ .replace("k.weight", "to_k.weight")
927
+ .replace("k.bias", "to_k.bias")
928
+ .replace("v.weight", "to_v.weight")
929
+ .replace("v.bias", "to_v.bias")
930
+ .replace("proj_out.weight", "to_out.0.weight")
931
+ .replace("proj_out.bias", "to_out.0.bias")
932
+ )
933
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
934
+
935
+ # proj_attn.weight has to be converted from conv 1D to linear
936
+ shape = new_checkpoint[diffusers_key].shape
937
+
938
+ if len(shape) == 3:
939
+ new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
940
+ elif len(shape) == 4:
941
+ new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
942
+
943
+
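The slicing at the end of update_vae_attentions_ldm_to_diffusers is the same trick as conv_attn_to_linear: a 1x1 convolution weight of shape (out, in, 1, 1) holds exactly the parameters of a linear layer of shape (out, in), so dropping the trailing singleton dimensions is sufficient. A standalone sketch (the 512 width is only an example):

import torch

conv_weight = torch.randn(512, 512, 1, 1)   # VAE attention projection stored as a 1x1 conv
linear_weight = conv_weight[:, :, 0, 0]     # same values, reshaped to (out_features, in_features)
assert linear_weight.shape == (512, 512)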
944
+ def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs):
945
+ is_stage_c = "clip_txt_mapper.weight" in checkpoint
946
+
947
+ if is_stage_c:
948
+ state_dict = {}
949
+ for key in checkpoint.keys():
950
+ if key.endswith("in_proj_weight"):
951
+ weights = checkpoint[key].chunk(3, 0)
952
+ state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
953
+ state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
954
+ state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
955
+ elif key.endswith("in_proj_bias"):
956
+ weights = checkpoint[key].chunk(3, 0)
957
+ state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
958
+ state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
959
+ state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
960
+ elif key.endswith("out_proj.weight"):
961
+ weights = checkpoint[key]
962
+ state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
963
+ elif key.endswith("out_proj.bias"):
964
+ weights = checkpoint[key]
965
+ state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
966
+ else:
967
+ state_dict[key] = checkpoint[key]
968
+ else:
969
+ state_dict = {}
970
+ for key in checkpoint.keys():
971
+ if key.endswith("in_proj_weight"):
972
+ weights = checkpoint[key].chunk(3, 0)
973
+ state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
974
+ state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
975
+ state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
976
+ elif key.endswith("in_proj_bias"):
977
+ weights = checkpoint[key].chunk(3, 0)
978
+ state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
979
+ state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
980
+ state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
981
+ elif key.endswith("out_proj.weight"):
982
+ weights = checkpoint[key]
983
+ state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
984
+ elif key.endswith("out_proj.bias"):
985
+ weights = checkpoint[key]
986
+ state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
987
+ # rename clip_mapper to clip_txt_pooled_mapper
988
+ elif key.endswith("clip_mapper.weight"):
989
+ weights = checkpoint[key]
990
+ state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
991
+ elif key.endswith("clip_mapper.bias"):
992
+ weights = checkpoint[key]
993
+ state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
994
+ else:
995
+ state_dict[key] = checkpoint[key]
996
+
997
+ return state_dict
998
+
999
+
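The Stable Cascade conversion above unpacks the fused attention projections that torch.nn.MultiheadAttention stores as a single stacked in_proj tensor. A self-contained sketch of that split (the 1536 width is only an example, matching the stage-C lite case):

import torch

in_proj_weight = torch.randn(3 * 1536, 1536)       # stacked q/k/v projection weights
to_q, to_k, to_v = in_proj_weight.chunk(3, dim=0)  # same split as checkpoint[key].chunk(3, 0) above
assert to_q.shape == to_k.shape == to_v.shape == (1536, 1536)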
1000
+ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs):
1001
+ """
1002
+ Takes a state dict and a config, and returns a converted checkpoint.
1003
+ """
1004
+ # extract state_dict for UNet
1005
+ unet_state_dict = {}
1006
+ keys = list(checkpoint.keys())
1007
+ unet_key = LDM_UNET_KEY
1008
+
1009
+ # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
1010
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
1011
+ logger.warning("Checkpoint has both EMA and non-EMA weights.")
1012
+ logger.warning(
1013
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
1014
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
1015
+ )
1016
+ for key in keys:
1017
+ if key.startswith("model.diffusion_model"):
1018
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
1019
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(flat_ema_key)
1020
+ else:
1021
+ if sum(k.startswith("model_ema") for k in keys) > 100:
1022
+ logger.warning(
1023
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
1024
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
1025
+ )
1026
+ for key in keys:
1027
+ if key.startswith(unet_key):
1028
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(key)
1029
+
1030
+ new_checkpoint = {}
1031
+ ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
1032
+ for diffusers_key, ldm_key in ldm_unet_keys.items():
703
1033
  if ldm_key not in unet_state_dict:
704
1034
  continue
705
1035
  new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
@@ -756,10 +1086,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
756
1086
  )
757
1087
 
758
1088
  if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
759
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
1089
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.get(
760
1090
  f"input_blocks.{i}.0.op.weight"
761
1091
  )
762
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
1092
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.get(
763
1093
  f"input_blocks.{i}.0.op.bias"
764
1094
  )
765
1095
 
@@ -773,19 +1103,22 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
773
1103
  )
774
1104
 
775
1105
  # Mid blocks
776
- resnet_0 = middle_blocks[0]
777
- attentions = middle_blocks[1]
778
- resnet_1 = middle_blocks[2]
779
-
780
- update_unet_resnet_ldm_to_diffusers(
781
- resnet_0, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.0", "new": "mid_block.resnets.0"}
782
- )
783
- update_unet_resnet_ldm_to_diffusers(
784
- resnet_1, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.2", "new": "mid_block.resnets.1"}
785
- )
786
- update_unet_attention_ldm_to_diffusers(
787
- attentions, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.1", "new": "mid_block.attentions.0"}
788
- )
1106
+ for key in middle_blocks.keys():
1107
+ diffusers_key = max(key - 1, 0)
1108
+ if key % 2 == 0:
1109
+ update_unet_resnet_ldm_to_diffusers(
1110
+ middle_blocks[key],
1111
+ new_checkpoint,
1112
+ unet_state_dict,
1113
+ mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
1114
+ )
1115
+ else:
1116
+ update_unet_attention_ldm_to_diffusers(
1117
+ middle_blocks[key],
1118
+ new_checkpoint,
1119
+ unet_state_dict,
1120
+ mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
1121
+ )
789
1122
 
790
1123
  # Up Blocks
791
1124
  for i in range(num_output_blocks):
@@ -834,7 +1167,11 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
834
1167
  def convert_controlnet_checkpoint(
835
1168
  checkpoint,
836
1169
  config,
1170
+ **kwargs,
837
1171
  ):
1172
+ # Return checkpoint if it's already been converted
1173
+ if "time_embedding.linear_1.weight" in checkpoint:
1174
+ return checkpoint
838
1175
  # Some controlnet ckpt files are distributed independently from the rest of the
839
1176
  # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
840
1177
  if "time_embed.0.weight" in checkpoint:
@@ -846,7 +1183,7 @@ def convert_controlnet_checkpoint(
846
1183
  controlnet_key = LDM_CONTROLNET_KEY
847
1184
  for key in keys:
848
1185
  if key.startswith(controlnet_key):
849
- controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.pop(key)
1186
+ controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key)
850
1187
 
851
1188
  new_checkpoint = {}
852
1189
  ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"]
@@ -880,10 +1217,10 @@ def convert_controlnet_checkpoint(
880
1217
  )
881
1218
 
882
1219
  if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict:
883
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.pop(
1220
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get(
884
1221
  f"input_blocks.{i}.0.op.weight"
885
1222
  )
886
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.pop(
1223
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get(
887
1224
  f"input_blocks.{i}.0.op.bias"
888
1225
  )
889
1226
 
@@ -898,8 +1235,8 @@ def convert_controlnet_checkpoint(
898
1235
 
899
1236
  # controlnet down blocks
900
1237
  for i in range(num_input_blocks):
901
- new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.weight")
902
- new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.bias")
1238
+ new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight")
1239
+ new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias")
903
1240
 
904
1241
  # Retrieves the keys for the middle blocks only
905
1242
  num_middle_blocks = len(
@@ -909,33 +1246,28 @@ def convert_controlnet_checkpoint(
909
1246
  layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key]
910
1247
  for layer_id in range(num_middle_blocks)
911
1248
  }
912
- if middle_blocks:
913
- resnet_0 = middle_blocks[0]
914
- attentions = middle_blocks[1]
915
- resnet_1 = middle_blocks[2]
916
1249
 
917
- update_unet_resnet_ldm_to_diffusers(
918
- resnet_0,
919
- new_checkpoint,
920
- controlnet_state_dict,
921
- mapping={"old": "middle_block.0", "new": "mid_block.resnets.0"},
922
- )
923
- update_unet_resnet_ldm_to_diffusers(
924
- resnet_1,
925
- new_checkpoint,
926
- controlnet_state_dict,
927
- mapping={"old": "middle_block.2", "new": "mid_block.resnets.1"},
928
- )
929
- update_unet_attention_ldm_to_diffusers(
930
- attentions,
931
- new_checkpoint,
932
- controlnet_state_dict,
933
- mapping={"old": "middle_block.1", "new": "mid_block.attentions.0"},
934
- )
1250
+ # Mid blocks
1251
+ for key in middle_blocks.keys():
1252
+ diffusers_key = max(key - 1, 0)
1253
+ if key % 2 == 0:
1254
+ update_unet_resnet_ldm_to_diffusers(
1255
+ middle_blocks[key],
1256
+ new_checkpoint,
1257
+ controlnet_state_dict,
1258
+ mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
1259
+ )
1260
+ else:
1261
+ update_unet_attention_ldm_to_diffusers(
1262
+ middle_blocks[key],
1263
+ new_checkpoint,
1264
+ controlnet_state_dict,
1265
+ mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
1266
+ )
935
1267
 
936
1268
  # mid block
937
- new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.pop("middle_block_out.0.weight")
938
- new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.pop("middle_block_out.0.bias")
1269
+ new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight")
1270
+ new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias")
939
1271
 
940
1272
  # controlnet cond embedding blocks
941
1273
  cond_embedding_blocks = {
@@ -949,94 +1281,26 @@ def convert_controlnet_checkpoint(
949
1281
  diffusers_idx = idx - 1
950
1282
  cond_block_id = 2 * idx
951
1283
 
952
- new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.pop(
1284
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get(
953
1285
  f"input_hint_block.{cond_block_id}.weight"
954
1286
  )
955
- new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.pop(
1287
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get(
956
1288
  f"input_hint_block.{cond_block_id}.bias"
957
1289
  )
958
1290
 
959
1291
  return new_checkpoint
960
1292
 
961
1293
 
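Both mid-block loops above rely on the same index arithmetic: LDM middle_block entries alternate resnet / attention / resnet, while diffusers keeps resnets and attentions in separate lists. A small sketch of the mapping:

def map_middle_block(key):
    # Even LDM indices are resnets, odd ones attentions; max(key - 1, 0) collapses them
    # onto the per-list diffusers indices.
    diffusers_key = max(key - 1, 0)
    kind = "resnets" if key % 2 == 0 else "attentions"
    return f"mid_block.{kind}.{diffusers_key}"

assert map_middle_block(0) == "mid_block.resnets.0"
assert map_middle_block(1) == "mid_block.attentions.0"
assert map_middle_block(2) == "mid_block.resnets.1"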
962
- def create_diffusers_controlnet_model_from_ldm(
963
- pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None, torch_dtype=None
964
- ):
965
- # import here to avoid circular imports
966
- from ..models import ControlNetModel
967
-
968
- image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
969
-
970
- diffusers_config = create_controlnet_diffusers_config(original_config, image_size=image_size)
971
- diffusers_config["upcast_attention"] = upcast_attention
972
-
973
- diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, diffusers_config)
974
-
975
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
976
- with ctx():
977
- controlnet = ControlNetModel(**diffusers_config)
978
-
979
- if is_accelerate_available():
980
- from ..models.modeling_utils import load_model_dict_into_meta
981
-
982
- unexpected_keys = load_model_dict_into_meta(
983
- controlnet, diffusers_format_controlnet_checkpoint, dtype=torch_dtype
984
- )
985
- if controlnet._keys_to_ignore_on_load_unexpected is not None:
986
- for pat in controlnet._keys_to_ignore_on_load_unexpected:
987
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
988
-
989
- if len(unexpected_keys) > 0:
990
- logger.warning(
991
- f"Some weights of the model checkpoint were not used when initializing {controlnet.__name__}: \n {[', '.join(unexpected_keys)]}"
992
- )
993
- else:
994
- controlnet.load_state_dict(diffusers_format_controlnet_checkpoint)
995
-
996
- if torch_dtype is not None:
997
- controlnet = controlnet.to(torch_dtype)
998
-
999
- return {"controlnet": controlnet}
1000
-
1001
-
1002
- def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
1003
- for ldm_key in keys:
1004
- diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
1005
- new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
1006
-
1007
-
1008
- def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
1009
- for ldm_key in keys:
1010
- diffusers_key = (
1011
- ldm_key.replace(mapping["old"], mapping["new"])
1012
- .replace("norm.weight", "group_norm.weight")
1013
- .replace("norm.bias", "group_norm.bias")
1014
- .replace("q.weight", "to_q.weight")
1015
- .replace("q.bias", "to_q.bias")
1016
- .replace("k.weight", "to_k.weight")
1017
- .replace("k.bias", "to_k.bias")
1018
- .replace("v.weight", "to_v.weight")
1019
- .replace("v.bias", "to_v.bias")
1020
- .replace("proj_out.weight", "to_out.0.weight")
1021
- .replace("proj_out.bias", "to_out.0.bias")
1022
- )
1023
- new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
1024
-
1025
- # proj_attn.weight has to be converted from conv 1D to linear
1026
- shape = new_checkpoint[diffusers_key].shape
1027
-
1028
- if len(shape) == 3:
1029
- new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
1030
- elif len(shape) == 4:
1031
- new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
1032
-
1033
-
1034
1294
  def convert_ldm_vae_checkpoint(checkpoint, config):
1035
1295
  # extract state dict for VAE
1036
1296
  # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
1037
1297
  vae_state_dict = {}
1038
1298
  keys = list(checkpoint.keys())
1039
- vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else ""
1299
+ vae_key = ""
1300
+ for ldm_vae_key in LDM_VAE_KEYS:
1301
+ if any(k.startswith(ldm_vae_key) for k in keys):
1302
+ vae_key = ldm_vae_key
1303
+
1040
1304
  for key in keys:
1041
1305
  if key.startswith(vae_key):
1042
1306
  vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
@@ -1063,10 +1327,10 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
1063
1327
  mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
1064
1328
  )
1065
1329
  if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
1066
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
1330
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
1067
1331
  f"encoder.down.{i}.downsample.conv.weight"
1068
1332
  )
1069
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
1333
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
1070
1334
  f"encoder.down.{i}.downsample.conv.bias"
1071
1335
  )
1072
1336
 
@@ -1131,79 +1395,38 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
1131
1395
  return new_checkpoint
1132
1396
 
1133
1397
 
1134
- def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=False, torch_dtype=None):
1135
- try:
1136
- config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
1137
- except Exception:
1138
- raise ValueError(
1139
- f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
1140
- )
1141
-
1142
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
1143
- with ctx():
1144
- text_model = CLIPTextModel(config)
1145
-
1398
+ def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
1146
1399
  keys = list(checkpoint.keys())
1147
1400
  text_model_dict = {}
1148
1401
 
1149
- remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE
1402
+ remove_prefixes = []
1403
+ remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
1404
+ if remove_prefix:
1405
+ remove_prefixes.append(remove_prefix)
1150
1406
 
1151
1407
  for key in keys:
1152
1408
  for prefix in remove_prefixes:
1153
1409
  if key.startswith(prefix):
1154
1410
  diffusers_key = key.replace(prefix, "")
1155
- text_model_dict[diffusers_key] = checkpoint[key]
1156
-
1157
- if is_accelerate_available():
1158
- from ..models.modeling_utils import load_model_dict_into_meta
1159
-
1160
- unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
1161
- if text_model._keys_to_ignore_on_load_unexpected is not None:
1162
- for pat in text_model._keys_to_ignore_on_load_unexpected:
1163
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1164
-
1165
- if len(unexpected_keys) > 0:
1166
- logger.warning(
1167
- f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
1168
- )
1169
- else:
1170
- if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
1171
- text_model_dict.pop("text_model.embeddings.position_ids", None)
1172
-
1173
- text_model.load_state_dict(text_model_dict)
1411
+ text_model_dict[diffusers_key] = checkpoint.get(key)
1174
1412
 
1175
- if torch_dtype is not None:
1176
- text_model = text_model.to(torch_dtype)
1177
-
1178
- return text_model
1413
+ return text_model_dict
1179
1414
 
1180
1415
 
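convert_ldm_clip_checkpoint above only needs to strip the LDM naming prefix; once the prefix is gone, the keys line up with the CLIPTextModel state dict from transformers. A toy illustration (the single key and dummy value are made up; "cond_stage_model.transformer." is one of the prefixes listed in LDM_CLIP_PREFIX_TO_REMOVE):

prefix = "cond_stage_model.transformer."
checkpoint = {prefix + "text_model.embeddings.token_embedding.weight": "dummy-tensor"}
stripped = {k[len(prefix):]: v for k, v in checkpoint.items() if k.startswith(prefix)}
assert "text_model.embeddings.token_embedding.weight" in stripped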
1181
- def create_text_encoder_from_open_clip_checkpoint(
1182
- config_name,
1416
+ def convert_open_clip_checkpoint(
1417
+ text_model,
1183
1418
  checkpoint,
1184
1419
  prefix="cond_stage_model.model.",
1185
- has_projection=False,
1186
- local_files_only=False,
1187
- torch_dtype=None,
1188
- **config_kwargs,
1189
1420
  ):
1190
- try:
1191
- config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)
1192
- except Exception:
1193
- raise ValueError(
1194
- f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'."
1195
- )
1196
-
1197
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
1198
- with ctx():
1199
- text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
1200
-
1201
1421
  text_model_dict = {}
1202
1422
  text_proj_key = prefix + "text_projection"
1203
- text_proj_dim = (
1204
- int(checkpoint[text_proj_key].shape[0]) if text_proj_key in checkpoint else LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
1205
- )
1206
- text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
1423
+
1424
+ if text_proj_key in checkpoint:
1425
+ text_proj_dim = int(checkpoint[text_proj_key].shape[0])
1426
+ elif hasattr(text_model.config, "projection_dim"):
1427
+ text_proj_dim = text_model.config.projection_dim
1428
+ else:
1429
+ text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
1207
1430
 
1208
1431
  keys = list(checkpoint.keys())
1209
1432
  keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE
@@ -1235,309 +1458,183 @@ def create_text_encoder_from_open_clip_checkpoint(
1235
1458
  )
1236
1459
 
1237
1460
  if key.endswith(".in_proj_weight"):
1238
- weight_value = checkpoint[key]
1461
+ weight_value = checkpoint.get(key)
1239
1462
 
1240
- text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :]
1241
- text_model_dict[diffusers_key + ".k_proj.weight"] = weight_value[text_proj_dim : text_proj_dim * 2, :]
1242
- text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :]
1463
+ text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :].clone().detach()
1464
+ text_model_dict[diffusers_key + ".k_proj.weight"] = (
1465
+ weight_value[text_proj_dim : text_proj_dim * 2, :].clone().detach()
1466
+ )
1467
+ text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :].clone().detach()
1243
1468
 
1244
1469
  elif key.endswith(".in_proj_bias"):
1245
- weight_value = checkpoint[key]
1246
- text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim]
1247
- text_model_dict[diffusers_key + ".k_proj.bias"] = weight_value[text_proj_dim : text_proj_dim * 2]
1248
- text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :]
1470
+ weight_value = checkpoint.get(key)
1471
+ text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim].clone().detach()
1472
+ text_model_dict[diffusers_key + ".k_proj.bias"] = (
1473
+ weight_value[text_proj_dim : text_proj_dim * 2].clone().detach()
1474
+ )
1475
+ text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :].clone().detach()
1249
1476
  else:
1250
- text_model_dict[diffusers_key] = checkpoint[key]
1477
+ text_model_dict[diffusers_key] = checkpoint.get(key)
1251
1478
 
1252
- if is_accelerate_available():
1253
- from ..models.modeling_utils import load_model_dict_into_meta
1479
+ return text_model_dict
1254
1480
 
1255
- unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
1256
- if text_model._keys_to_ignore_on_load_unexpected is not None:
1257
- for pat in text_model._keys_to_ignore_on_load_unexpected:
1258
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1259
1481
 
1260
- if len(unexpected_keys) > 0:
1261
- logger.warning(
1262
- f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
1482
+ def create_diffusers_clip_model_from_ldm(
1483
+ cls,
1484
+ checkpoint,
1485
+ subfolder="",
1486
+ config=None,
1487
+ torch_dtype=None,
1488
+ local_files_only=None,
1489
+ is_legacy_loading=False,
1490
+ ):
1491
+ if config:
1492
+ config = {"pretrained_model_name_or_path": config}
1493
+ else:
1494
+ config = fetch_diffusers_config(checkpoint)
1495
+
1496
+ # For backwards compatibility
1497
+ # Older versions of `from_single_file` expected CLIP configs to be placed in their original transformers model repo
1498
+ # in the cache_dir, rather than in a subfolder of the Diffusers model
1499
+ if is_legacy_loading:
1500
+ logger.warning(
1501
+ (
1502
+ "Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update "
1503
+ "the local cache directory with the necessary CLIP model config files. "
1504
+ "Attempting to load CLIP model from legacy cache directory."
1263
1505
  )
1506
+ )
1264
1507
 
1265
- else:
1266
- if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
1267
- text_model_dict.pop("text_model.embeddings.position_ids", None)
1508
+ if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
1509
+ clip_config = "openai/clip-vit-large-patch14"
1510
+ config["pretrained_model_name_or_path"] = clip_config
1511
+ subfolder = ""
1268
1512
 
1269
- text_model.load_state_dict(text_model_dict)
1513
+ elif is_open_clip_model(checkpoint):
1514
+ clip_config = "stabilityai/stable-diffusion-2"
1515
+ config["pretrained_model_name_or_path"] = clip_config
1516
+ subfolder = "text_encoder"
1270
1517
 
1271
- if torch_dtype is not None:
1272
- text_model = text_model.to(torch_dtype)
1518
+ else:
1519
+ clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
1520
+ config["pretrained_model_name_or_path"] = clip_config
1521
+ subfolder = ""
1273
1522
 
1274
- return text_model
1523
+ model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
1524
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
1525
+ with ctx():
1526
+ model = cls(model_config)
1275
1527
 
1528
+ position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1]
1276
1529
 
1277
- def create_diffusers_unet_model_from_ldm(
1278
- pipeline_class_name,
1279
- original_config,
1280
- checkpoint,
1281
- num_in_channels=None,
1282
- upcast_attention=None,
1283
- extract_ema=False,
1284
- image_size=None,
1285
- torch_dtype=None,
1286
- model_type=None,
1287
- ):
1288
- from ..models import UNet2DConditionModel
1530
+ if is_clip_model(checkpoint):
1531
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
1289
1532
 
1290
- if num_in_channels is None:
1291
- if pipeline_class_name in [
1292
- "StableDiffusionInpaintPipeline",
1293
- "StableDiffusionControlNetInpaintPipeline",
1294
- "StableDiffusionXLInpaintPipeline",
1295
- "StableDiffusionXLControlNetInpaintPipeline",
1296
- ]:
1297
- num_in_channels = 9
1533
+ elif (
1534
+ is_clip_sdxl_model(checkpoint)
1535
+ and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim
1536
+ ):
1537
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
1298
1538
 
1299
- elif pipeline_class_name == "StableDiffusionUpscalePipeline":
1300
- num_in_channels = 7
1539
+ elif (
1540
+ is_clip_sd3_model(checkpoint)
1541
+ and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
1542
+ ):
1543
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
1544
+ diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
1301
1545
 
1302
- else:
1303
- num_in_channels = 4
1546
+ elif is_open_clip_model(checkpoint):
1547
+ prefix = "cond_stage_model.model."
1548
+ diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
1304
1549
 
1305
- image_size = set_image_size(
1306
- pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
1307
- )
1308
- unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
1309
- unet_config["in_channels"] = num_in_channels
1310
- if upcast_attention is not None:
1311
- unet_config["upcast_attention"] = upcast_attention
1550
+ elif (
1551
+ is_open_clip_sdxl_model(checkpoint)
1552
+ and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim
1553
+ ):
1554
+ prefix = "conditioner.embedders.1.model."
1555
+ diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
1312
1556
 
1313
- diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
1314
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
1557
+ elif is_open_clip_sdxl_refiner_model(checkpoint):
1558
+ prefix = "conditioner.embedders.0.model."
1559
+ diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
1315
1560
 
1316
- with ctx():
1317
- unet = UNet2DConditionModel(**unet_config)
1561
+ elif (
1562
+ is_open_clip_sd3_model(checkpoint)
1563
+ and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
1564
+ ):
1565
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
1566
+
1567
+ else:
1568
+ raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
1318
1569
 
1319
1570
  if is_accelerate_available():
1320
- from ..models.modeling_utils import load_model_dict_into_meta
1571
+ unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
1572
+ else:
1573
+ _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
1321
1574
 
1322
- unexpected_keys = load_model_dict_into_meta(unet, diffusers_format_unet_checkpoint, dtype=torch_dtype)
1323
- if unet._keys_to_ignore_on_load_unexpected is not None:
1324
- for pat in unet._keys_to_ignore_on_load_unexpected:
1325
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1575
+ if model._keys_to_ignore_on_load_unexpected is not None:
1576
+ for pat in model._keys_to_ignore_on_load_unexpected:
1577
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1326
1578
 
1327
- if len(unexpected_keys) > 0:
1328
- logger.warning(
1329
- f"Some weights of the model checkpoint were not used when initializing {unet.__name__}: \n {[', '.join(unexpected_keys)]}"
1330
- )
1331
- else:
1332
- unet.load_state_dict(diffusers_format_unet_checkpoint)
1579
+ if len(unexpected_keys) > 0:
1580
+ logger.warning(
1581
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1582
+ )
1333
1583
 
1334
1584
  if torch_dtype is not None:
1335
- unet = unet.to(torch_dtype)
1585
+ model.to(torch_dtype)
1586
+
1587
+ model.eval()
1336
1588
 
1337
- return {"unet": unet}
1589
+ return model
1338
1590
 
1339
1591
 
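A hedged sketch of how create_diffusers_clip_model_from_ldm is driven (in the library it is invoked by the single-file loading machinery rather than called directly; the checkpoint path here is hypothetical):

from safetensors.torch import load_file
from transformers import CLIPTextModel

checkpoint = load_file("v1-5-pruned-emaonly.safetensors")  # hypothetical single-file checkpoint
# The transformers class is passed as `cls`; its config is fetched from the matching hub repo
# and the flat checkpoint supplies the converted weights.
text_encoder = create_diffusers_clip_model_from_ldm(CLIPTextModel, checkpoint)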
1340
- def create_diffusers_vae_model_from_ldm(
1341
- pipeline_class_name,
1342
- original_config,
1592
+ def _legacy_load_scheduler(
1593
+ cls,
1343
1594
  checkpoint,
1344
- image_size=None,
1345
- scaling_factor=None,
1346
- torch_dtype=None,
1347
- model_type=None,
1595
+ component_name,
1596
+ original_config=None,
1597
+ **kwargs,
1348
1598
  ):
1349
- # import here to avoid circular imports
1350
- from ..models import AutoencoderKL
1351
-
1352
- image_size = set_image_size(
1353
- pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
1354
- )
1355
- model_type = infer_model_type(original_config, checkpoint, model_type)
1356
-
1357
- if model_type == "Playground":
1358
- edm_mean = (
1359
- checkpoint["edm_mean"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_mean"].tolist()
1599
+ scheduler_type = kwargs.get("scheduler_type", None)
1600
+ prediction_type = kwargs.get("prediction_type", None)
1601
+
1602
+ if scheduler_type is not None:
1603
+ deprecation_message = (
1604
+ "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n"
1605
+ "Example:\n\n"
1606
+ "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
1607
+ "scheduler = DDIMScheduler()\n"
1608
+ "pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
1360
1609
  )
1361
- edm_std = (
1362
- checkpoint["edm_std"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_std"].tolist()
1610
+ deprecate("scheduler_type", "1.0.0", deprecation_message)
1611
+
1612
+ if prediction_type is not None:
1613
+ deprecation_message = (
1614
+ "Please configure an instance of a Scheduler with the appropriate `prediction_type` and "
1615
+ "pass the object directly to the `scheduler` argument in `from_single_file`.\n\n"
1616
+ "Example:\n\n"
1617
+ "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
1618
+ 'scheduler = DDIMScheduler(prediction_type="v_prediction")\n'
1619
+ "pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
1363
1620
  )
1364
- else:
1365
- edm_mean = None
1366
- edm_std = None
1367
-
1368
- vae_config = create_vae_diffusers_config(
1369
- original_config,
1370
- image_size=image_size,
1371
- scaling_factor=scaling_factor,
1372
- latents_mean=edm_mean,
1373
- latents_std=edm_std,
1374
- )
1375
- diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
1376
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
1621
+ deprecate("prediction_type", "1.0.0", deprecation_message)
1377
1622
 
1378
- with ctx():
1379
- vae = AutoencoderKL(**vae_config)
1380
-
1381
- if is_accelerate_available():
1382
- from ..models.modeling_utils import load_model_dict_into_meta
1623
+ scheduler_config = SCHEDULER_DEFAULT_CONFIG
1624
+ model_type = infer_diffusers_model_type(checkpoint=checkpoint)
1383
1625
 
1384
- unexpected_keys = load_model_dict_into_meta(vae, diffusers_format_vae_checkpoint, dtype=torch_dtype)
1385
- if vae._keys_to_ignore_on_load_unexpected is not None:
1386
- for pat in vae._keys_to_ignore_on_load_unexpected:
1387
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1626
+ global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
1388
1627
 
1389
- if len(unexpected_keys) > 0:
1390
- logger.warning(
1391
- f"Some weights of the model checkpoint were not used when initializing {vae.__name__}: \n {[', '.join(unexpected_keys)]}"
1392
- )
1628
+ if original_config:
1629
+ num_train_timesteps = original_config["model"]["params"].get("timesteps", 1000)
1393
1630
  else:
1394
- vae.load_state_dict(diffusers_format_vae_checkpoint)
1395
-
1396
- if torch_dtype is not None:
1397
- vae = vae.to(torch_dtype)
1398
-
1399
- return {"vae": vae}
1400
-
1401
-
1402
- def create_text_encoders_and_tokenizers_from_ldm(
1403
- original_config,
1404
- checkpoint,
1405
- model_type=None,
1406
- local_files_only=False,
1407
- torch_dtype=None,
1408
- ):
1409
- model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
1410
-
1411
- if model_type == "FrozenOpenCLIPEmbedder":
1412
- config_name = "stabilityai/stable-diffusion-2"
1413
- config_kwargs = {"subfolder": "text_encoder"}
1414
-
1415
- try:
1416
- text_encoder = create_text_encoder_from_open_clip_checkpoint(
1417
- config_name, checkpoint, local_files_only=local_files_only, torch_dtype=torch_dtype, **config_kwargs
1418
- )
1419
- tokenizer = CLIPTokenizer.from_pretrained(
1420
- config_name, subfolder="tokenizer", local_files_only=local_files_only
1421
- )
1422
- except Exception:
1423
- raise ValueError(
1424
- f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder in the following path: '{config_name}'."
1425
- )
1426
- else:
1427
- return {"text_encoder": text_encoder, "tokenizer": tokenizer}
1428
-
1429
- elif model_type == "FrozenCLIPEmbedder":
1430
- try:
1431
- config_name = "openai/clip-vit-large-patch14"
1432
- text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
1433
- config_name,
1434
- checkpoint,
1435
- local_files_only=local_files_only,
1436
- torch_dtype=torch_dtype,
1437
- )
1438
- tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
1439
-
1440
- except Exception:
1441
- raise ValueError(
1442
- f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: '{config_name}'."
1443
- )
1444
- else:
1445
- return {"text_encoder": text_encoder, "tokenizer": tokenizer}
1446
-
1447
- elif model_type == "SDXL-Refiner":
1448
- config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
1449
- config_kwargs = {"projection_dim": 1280}
1450
- prefix = "conditioner.embedders.0.model."
1451
-
1452
- try:
1453
- tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only)
1454
- text_encoder_2 = create_text_encoder_from_open_clip_checkpoint(
1455
- config_name,
1456
- checkpoint,
1457
- prefix=prefix,
1458
- has_projection=True,
1459
- local_files_only=local_files_only,
1460
- torch_dtype=torch_dtype,
1461
- **config_kwargs,
1462
- )
1463
- except Exception:
1464
- raise ValueError(
1465
- f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'."
1466
- )
1467
-
1468
- else:
1469
- return {
1470
- "text_encoder": None,
1471
- "tokenizer": None,
1472
- "tokenizer_2": tokenizer_2,
1473
- "text_encoder_2": text_encoder_2,
1474
- }
1475
-
1476
- elif model_type in ["SDXL", "Playground"]:
1477
- try:
1478
- config_name = "openai/clip-vit-large-patch14"
1479
- tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
1480
- text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
1481
- config_name, checkpoint, local_files_only=local_files_only, torch_dtype=torch_dtype
1482
- )
1483
-
1484
- except Exception:
1485
- raise ValueError(
1486
- f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder and tokenizer in the following path: 'openai/clip-vit-large-patch14'."
1487
- )
1488
-
1489
- try:
1490
- config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
1491
- config_kwargs = {"projection_dim": 1280}
1492
- prefix = "conditioner.embedders.1.model."
1493
- tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only)
1494
- text_encoder_2 = create_text_encoder_from_open_clip_checkpoint(
1495
- config_name,
1496
- checkpoint,
1497
- prefix=prefix,
1498
- has_projection=True,
1499
- local_files_only=local_files_only,
1500
- torch_dtype=torch_dtype,
1501
- **config_kwargs,
1502
- )
1503
- except Exception:
1504
- raise ValueError(
1505
- f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'."
1506
- )
1631
+ num_train_timesteps = 1000
1507
1632
 
1508
- return {
1509
- "tokenizer": tokenizer,
1510
- "text_encoder": text_encoder,
1511
- "tokenizer_2": tokenizer_2,
1512
- "text_encoder_2": text_encoder_2,
1513
- }
1514
-
1515
- return
1516
-
1517
-
1518
- def create_scheduler_from_ldm(
1519
- pipeline_class_name,
1520
- original_config,
1521
- checkpoint,
1522
- prediction_type=None,
1523
- scheduler_type="ddim",
1524
- model_type=None,
1525
- ):
1526
- scheduler_config = get_default_scheduler_config()
1527
- model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
1528
-
1529
- global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
1530
-
1531
- num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", None) or 1000
1532
1633
  scheduler_config["num_train_timesteps"] = num_train_timesteps
1533
1634
 
1534
- if (
1535
- "parameterization" in original_config["model"]["params"]
1536
- and original_config["model"]["params"]["parameterization"] == "v"
1537
- ):
1635
+ if model_type == "v2":
1538
1636
  if prediction_type is None:
1539
- # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
1540
- # as it relies on a brittle global step parameter here
1637
+ # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type="epsilon"`, as it relies on a brittle global step parameter here
1541
1638
  prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
1542
1639
 
1543
1640
  else:
@@ -1545,20 +1642,44 @@ def create_scheduler_from_ldm(
1545
1642
 
1546
1643
  scheduler_config["prediction_type"] = prediction_type
1547
1644
 
1548
- if model_type in ["SDXL", "SDXL-Refiner"]:
1645
+ if model_type in ["xl_base", "xl_refiner"]:
1549
1646
  scheduler_type = "euler"
1550
- elif model_type == "Playground":
1647
+ elif model_type == "playground":
1551
1648
  scheduler_type = "edm_dpm_solver_multistep"
1552
1649
  else:
1553
- beta_start = original_config["model"]["params"].get("linear_start", 0.02)
1554
- beta_end = original_config["model"]["params"].get("linear_end", 0.085)
1650
+ if original_config:
1651
+ beta_start = original_config["model"]["params"].get("linear_start")
1652
+ beta_end = original_config["model"]["params"].get("linear_end")
1653
+
1654
+ else:
1655
+ beta_start = 0.02
1656
+ beta_end = 0.085
1657
+
1555
1658
  scheduler_config["beta_start"] = beta_start
1556
1659
  scheduler_config["beta_end"] = beta_end
1557
1660
  scheduler_config["beta_schedule"] = "scaled_linear"
1558
1661
  scheduler_config["clip_sample"] = False
1559
1662
  scheduler_config["set_alpha_to_one"] = False
1560
1663
 
1561
- if scheduler_type == "pndm":
1664
+ # To deal with an edge case: the StableDiffusionUpscale pipeline has two schedulers
1665
+ if component_name == "low_res_scheduler":
1666
+ return cls.from_config(
1667
+ {
1668
+ "beta_end": 0.02,
1669
+ "beta_schedule": "scaled_linear",
1670
+ "beta_start": 0.0001,
1671
+ "clip_sample": True,
1672
+ "num_train_timesteps": 1000,
1673
+ "prediction_type": "epsilon",
1674
+ "trained_betas": None,
1675
+ "variance_type": "fixed_small",
1676
+ }
1677
+ )
1678
+
1679
+ if scheduler_type is None:
1680
+ return cls.from_config(scheduler_config)
1681
+
1682
+ elif scheduler_type == "pndm":
1562
1683
  scheduler_config["skip_prk_steps"] = True
1563
1684
  scheduler = PNDMScheduler.from_config(scheduler_config)
1564
1685
 
@@ -1603,15 +1724,964 @@ def create_scheduler_from_ldm(
1603
1724
  else:
1604
1725
  raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
1605
1726
 
1606
- if pipeline_class_name == "StableDiffusionUpscalePipeline":
1607
- scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler")
1608
- low_res_scheduler = DDPMScheduler.from_pretrained(
1609
- "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler"
1727
+ return scheduler
1728
+
1729
+
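For reference, a minimal sketch (not part of the diff) of the kind of config this branch assembles for a non-SDXL checkpoint and how it becomes a scheduler. The concrete values are assumptions taken from a typical v1-style LDM yaml (`linear_start`/`linear_end`), not from this file.

from diffusers import DDIMScheduler

# Assumed values for illustration only; the real config is built above from the checkpoint/original_config.
scheduler_config = {
    "num_train_timesteps": 1000,
    "beta_start": 0.00085,          # "linear_start" in the original LDM config
    "beta_end": 0.012,              # "linear_end" in the original LDM config
    "beta_schedule": "scaled_linear",
    "clip_sample": False,
    "set_alpha_to_one": False,
    "prediction_type": "epsilon",
}
scheduler = DDIMScheduler.from_config(scheduler_config)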
1730
+ def _legacy_load_clip_tokenizer(cls, checkpoint, config=None, local_files_only=False):
1731
+ if config:
1732
+ config = {"pretrained_model_name_or_path": config}
1733
+ else:
1734
+ config = fetch_diffusers_config(checkpoint)
1735
+
1736
+ if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
1737
+ clip_config = "openai/clip-vit-large-patch14"
1738
+ config["pretrained_model_name_or_path"] = clip_config
1739
+ subfolder = ""
1740
+
1741
+ elif is_open_clip_model(checkpoint):
1742
+ clip_config = "stabilityai/stable-diffusion-2"
1743
+ config["pretrained_model_name_or_path"] = clip_config
1744
+ subfolder = "tokenizer"
1745
+
1746
+ else:
1747
+ clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
1748
+ config["pretrained_model_name_or_path"] = clip_config
1749
+ subfolder = ""
1750
+
1751
+ tokenizer = cls.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
1752
+
1753
+ return tokenizer
1754
+
1755
+
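A rough usage sketch for the helper above, assuming a single-file SD 1.x checkpoint loaded from a hypothetical local path; for CLIP-style checkpoints it resolves to the openai/clip-vit-large-patch14 tokenizer.

from safetensors.torch import load_file
from transformers import CLIPTokenizer

# Hypothetical path; any single-file SD 1.x/2.x/SDXL checkpoint would do.
checkpoint = load_file("v1-5-pruned-emaonly.safetensors")
tokenizer = _legacy_load_clip_tokenizer(CLIPTokenizer, checkpoint, local_files_only=False)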
1756
+ def _legacy_load_safety_checker(local_files_only, torch_dtype):
1757
+ # Support for loading safety checker components using the deprecated
1758
+ # `load_safety_checker` argument.
1759
+
1760
+ from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
1761
+
1762
+ feature_extractor = AutoImageProcessor.from_pretrained(
1763
+ "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
1764
+ )
1765
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(
1766
+ "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
1767
+ )
1768
+
1769
+ return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}
1770
+
1771
+
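A minimal sketch of how the returned components might be used when the deprecated `load_safety_checker` path is taken (assumed usage, not shown in this file).

import torch

components = _legacy_load_safety_checker(local_files_only=False, torch_dtype=torch.float16)
safety_checker = components["safety_checker"]          # StableDiffusionSafetyChecker
feature_extractor = components["feature_extractor"]    # AutoImageProcessor instance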
1772
+ # In the original SD3 implementation of AdaLayerNormContinuous, the linear projection output is split into shift, scale;
1773
+ # in diffusers it is split into scale, shift. Here we swap the linear projection weights so that the diffusers implementation can be used.
1774
+ def swap_scale_shift(weight, dim):
1775
+ shift, scale = weight.chunk(2, dim=0)
1776
+ new_weight = torch.cat([scale, shift], dim=0)
1777
+ return new_weight
1778
+
1779
+
1780
+ def swap_proj_gate(weight):
1781
+ proj, gate = weight.chunk(2, dim=0)
1782
+ new_weight = torch.cat([gate, proj], dim=0)
1783
+ return new_weight
1784
+
1785
+
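A tiny, self-contained check of what the two swap helpers above do to a fused weight; the tensors are toy values, not real checkpoint weights.

import torch

w = torch.arange(4.0)              # pretend layout: [shift_0, shift_1, scale_0, scale_1]
print(swap_scale_shift(w, dim=0))  # tensor([2., 3., 0., 1.]) -> scale first, shift second

g = torch.arange(6.0)              # pretend layout: [proj_0..2, gate_0..2]
print(swap_proj_gate(g))           # tensor([3., 4., 5., 0., 1., 2.]) -> gate first, proj second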
1786
+ def get_attn2_layers(state_dict):
1787
+ attn2_layers = []
1788
+ for key in state_dict.keys():
1789
+ if "attn2." in key:
1790
+ # Extract the layer number from the key
1791
+ layer_num = int(key.split(".")[1])
1792
+ attn2_layers.append(layer_num)
1793
+
1794
+ return tuple(sorted(set(attn2_layers)))
1795
+
1796
+
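On a toy SD3-style state dict, the helper picks out just the blocks that carry a second attention module (the tensor shapes are irrelevant here).

import torch

toy_state_dict = {
    "joint_blocks.0.x_block.attn2.qkv.weight": torch.zeros(1),
    "joint_blocks.5.x_block.attn2.qkv.weight": torch.zeros(1),
    "joint_blocks.3.x_block.attn.qkv.weight": torch.zeros(1),   # plain attn, ignored
}
print(get_attn2_layers(toy_state_dict))  # (0, 5)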
1797
+ def get_caption_projection_dim(state_dict):
1798
+ caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
1799
+ return caption_projection_dim
1800
+
1801
+
1802
+ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
1803
+ converted_state_dict = {}
1804
+ keys = list(checkpoint.keys())
1805
+ for k in keys:
1806
+ if "model.diffusion_model." in k:
1807
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
1808
+
1809
+ num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1 # noqa: C401
1810
+ dual_attention_layers = get_attn2_layers(checkpoint)
1811
+
1812
+ caption_projection_dim = get_caption_projection_dim(checkpoint)
1813
+ has_qk_norm = any("ln_q" in key for key in checkpoint.keys())
1814
+
1815
+ # Positional and patch embeddings.
1816
+ converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
1817
+ converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
1818
+ converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
1819
+
1820
+ # Timestep embeddings.
1821
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
1822
+ "t_embedder.mlp.0.weight"
1823
+ )
1824
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
1825
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
1826
+ "t_embedder.mlp.2.weight"
1827
+ )
1828
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
1829
+
1830
+ # Context projections.
1831
+ converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight")
1832
+ converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias")
1833
+
1834
+ # Pooled context projection.
1835
+ converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight")
1836
+ converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias")
1837
+ converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight")
1838
+ converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias")
1839
+
1840
+ # Transformer blocks 🎸.
1841
+ for i in range(num_layers):
1842
+ # Q, K, V
1843
+ sample_q, sample_k, sample_v = torch.chunk(
1844
+ checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0
1845
+ )
1846
+ context_q, context_k, context_v = torch.chunk(
1847
+ checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0
1848
+ )
1849
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
1850
+ checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0
1851
+ )
1852
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
1853
+ checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0
1610
1854
  )
1611
1855
 
1612
- return {
1613
- "scheduler": scheduler,
1614
- "low_res_scheduler": low_res_scheduler,
1615
- }
1856
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q])
1857
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias])
1858
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k])
1859
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias])
1860
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v])
1861
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias])
1862
+
1863
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q])
1864
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias])
1865
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k])
1866
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias])
1867
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
1868
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
1869
+
1870
+ # qk norm
1871
+ if has_qk_norm:
1872
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop(
1873
+ f"joint_blocks.{i}.x_block.attn.ln_q.weight"
1874
+ )
1875
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop(
1876
+ f"joint_blocks.{i}.x_block.attn.ln_k.weight"
1877
+ )
1878
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop(
1879
+ f"joint_blocks.{i}.context_block.attn.ln_q.weight"
1880
+ )
1881
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop(
1882
+ f"joint_blocks.{i}.context_block.attn.ln_k.weight"
1883
+ )
1616
1884
 
1617
- return {"scheduler": scheduler}
1885
+ # output projections.
1886
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
1887
+ f"joint_blocks.{i}.x_block.attn.proj.weight"
1888
+ )
1889
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop(
1890
+ f"joint_blocks.{i}.x_block.attn.proj.bias"
1891
+ )
1892
+ if not (i == num_layers - 1):
1893
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop(
1894
+ f"joint_blocks.{i}.context_block.attn.proj.weight"
1895
+ )
1896
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop(
1897
+ f"joint_blocks.{i}.context_block.attn.proj.bias"
1898
+ )
1899
+
1900
+ if i in dual_attention_layers:
1901
+ # Q, K, V
1902
+ sample_q2, sample_k2, sample_v2 = torch.chunk(
1903
+ checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
1904
+ )
1905
+ sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
1906
+ checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
1907
+ )
1908
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
1909
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
1910
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
1911
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
1912
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
1913
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
1914
+
1915
+ # qk norm
1916
+ if has_qk_norm:
1917
+ converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop(
1918
+ f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
1919
+ )
1920
+ converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop(
1921
+ f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
1922
+ )
1923
+
1924
+ # output projections.
1925
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
1926
+ f"joint_blocks.{i}.x_block.attn2.proj.weight"
1927
+ )
1928
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
1929
+ f"joint_blocks.{i}.x_block.attn2.proj.bias"
1930
+ )
1931
+
1932
+ # norms.
1933
+ converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
1934
+ f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
1935
+ )
1936
+ converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop(
1937
+ f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias"
1938
+ )
1939
+ if not (i == num_layers - 1):
1940
+ converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop(
1941
+ f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"
1942
+ )
1943
+ converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop(
1944
+ f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"
1945
+ )
1946
+ else:
1947
+ converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift(
1948
+ checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"),
1949
+ dim=caption_projection_dim,
1950
+ )
1951
+ converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift(
1952
+ checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"),
1953
+ dim=caption_projection_dim,
1954
+ )
1955
+
1956
+ # ffs.
1957
+ converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop(
1958
+ f"joint_blocks.{i}.x_block.mlp.fc1.weight"
1959
+ )
1960
+ converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop(
1961
+ f"joint_blocks.{i}.x_block.mlp.fc1.bias"
1962
+ )
1963
+ converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop(
1964
+ f"joint_blocks.{i}.x_block.mlp.fc2.weight"
1965
+ )
1966
+ converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop(
1967
+ f"joint_blocks.{i}.x_block.mlp.fc2.bias"
1968
+ )
1969
+ if not (i == num_layers - 1):
1970
+ converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop(
1971
+ f"joint_blocks.{i}.context_block.mlp.fc1.weight"
1972
+ )
1973
+ converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop(
1974
+ f"joint_blocks.{i}.context_block.mlp.fc1.bias"
1975
+ )
1976
+ converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop(
1977
+ f"joint_blocks.{i}.context_block.mlp.fc2.weight"
1978
+ )
1979
+ converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop(
1980
+ f"joint_blocks.{i}.context_block.mlp.fc2.bias"
1981
+ )
1982
+
1983
+ # Final blocks.
1984
+ converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
1985
+ converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
1986
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
1987
+ checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
1988
+ )
1989
+ converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
1990
+ checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
1991
+ )
1992
+
1993
+ return converted_state_dict
1994
+
1995
+
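A rough usage sketch for the SD3 converter above; the checkpoint path is a hypothetical local file, and loading the result into a model with a matching config is assumed rather than shown here.

from safetensors.torch import load_file

checkpoint = load_file("sd3_medium.safetensors")  # hypothetical single-file SD3 checkpoint
converted = convert_sd3_transformer_checkpoint_to_diffusers(checkpoint)

# The converted dict uses diffusers naming, e.g. transformer_blocks.0.attn.to_q.weight,
# and can be loaded into SD3Transformer2DModel with a matching config.
print(sorted(k for k in converted if k.startswith("transformer_blocks.0."))[:5])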
1996
+ def is_t5_in_single_file(checkpoint):
1997
+ if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint:
1998
+ return True
1999
+
2000
+ return False
2001
+
2002
+
2003
+ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
2004
+ keys = list(checkpoint.keys())
2005
+ text_model_dict = {}
2006
+
2007
+ remove_prefixes = ["text_encoders.t5xxl.transformer."]
2008
+
2009
+ for key in keys:
2010
+ for prefix in remove_prefixes:
2011
+ if key.startswith(prefix):
2012
+ diffusers_key = key.replace(prefix, "")
2013
+ text_model_dict[diffusers_key] = checkpoint.get(key)
2014
+
2015
+ return text_model_dict
2016
+
2017
+
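A short sketch combining the two helpers above, assuming a hypothetical single-file SD3 checkpoint that bundles the T5-XXL encoder.

from safetensors.torch import load_file

checkpoint = load_file("sd3_medium_incl_clips_t5xxlfp16.safetensors")  # hypothetical path

if is_t5_in_single_file(checkpoint):
    t5_state_dict = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
    # "text_encoders.t5xxl.transformer.shared.weight" becomes "shared.weight", and so on.
    print("shared.weight" in t5_state_dict)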
2018
+ def create_diffusers_t5_model_from_checkpoint(
2019
+ cls,
2020
+ checkpoint,
2021
+ subfolder="",
2022
+ config=None,
2023
+ torch_dtype=None,
2024
+ local_files_only=None,
2025
+ ):
2026
+ if config:
2027
+ config = {"pretrained_model_name_or_path": config}
2028
+ else:
2029
+ config = fetch_diffusers_config(checkpoint)
2030
+
2031
+ model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
2032
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
2033
+ with ctx():
2034
+ model = cls(model_config)
2035
+
2036
+ diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
2037
+
2038
+ if is_accelerate_available():
2039
+ unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
2040
+ if model._keys_to_ignore_on_load_unexpected is not None:
2041
+ for pat in model._keys_to_ignore_on_load_unexpected:
2042
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
2043
+
2044
+ if len(unexpected_keys) > 0:
2045
+ logger.warning(
2046
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
2047
+ )
2048
+
2049
+ else:
2050
+ model.load_state_dict(diffusers_format_checkpoint)
2051
+
2052
+ use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
2053
+ if use_keep_in_fp32_modules:
2054
+ keep_in_fp32_modules = model._keep_in_fp32_modules
2055
+ else:
2056
+ keep_in_fp32_modules = []
2057
+
2058
+ if keep_in_fp32_modules is not None:
2059
+ for name, param in model.named_parameters():
2060
+ if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
2061
+ # `param = param.to(torch.float32)` does not work here, as it would only rebind the local name; mutate `param.data` instead.
2062
+ param.data = param.data.to(torch.float32)
2063
+
2064
+ return model
2065
+
2066
+
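A hedged usage sketch for the loader above: the checkpoint path is hypothetical, and the subfolder assumes an SD3-style repository layout for the bundled T5 encoder.

import torch
from safetensors.torch import load_file
from transformers import T5EncoderModel

checkpoint = load_file("sd3_medium_incl_clips_t5xxlfp16.safetensors")  # hypothetical path
text_encoder_3 = create_diffusers_t5_model_from_checkpoint(
    T5EncoderModel, checkpoint, subfolder="text_encoder_3", torch_dtype=torch.float16
)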
2067
+ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
2068
+ converted_state_dict = {}
2069
+ for k, v in checkpoint.items():
2070
+ if "pos_encoder" in k:
2071
+ continue
2072
+
2073
+ else:
2074
+ converted_state_dict[
2075
+ k.replace(".norms.0", ".norm1")
2076
+ .replace(".norms.1", ".norm2")
2077
+ .replace(".ff_norm", ".norm3")
2078
+ .replace(".attention_blocks.0", ".attn1")
2079
+ .replace(".attention_blocks.1", ".attn2")
2080
+ .replace(".temporal_transformer", "")
2081
+ ] = v
2082
+
2083
+ return converted_state_dict
2084
+
2085
+
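A toy example of the renaming performed above; only the key names matter, the tensors are placeholders.

import torch

toy = {
    "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight": torch.zeros(1),
    "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.weight": torch.zeros(1),
    "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.pos_encoder.pe": torch.zeros(1),
}
converted = convert_animatediff_checkpoint_to_diffusers(toy)
# ".attention_blocks.0" -> ".attn1", ".ff_norm" -> ".norm3", ".temporal_transformer" is dropped,
# and positional-encoding buffers ("pos_encoder") are skipped entirely.
print(sorted(converted))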
2086
+ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
2087
+ converted_state_dict = {}
2088
+ keys = list(checkpoint.keys())
2089
+ for k in keys:
2090
+ if "model.diffusion_model." in k:
2091
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
2092
+
2093
+ num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
2094
+ num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
2095
+ mlp_ratio = 4.0
2096
+ inner_dim = 3072
2097
+
2098
+ # In the original SD3 implementation of AdaLayerNormContinuous, the linear projection output is split into shift, scale;
2099
+ # in diffusers it is split into scale, shift. Here we swap the linear projection weights so that the diffusers implementation can be used.
2100
+ def swap_scale_shift(weight):
2101
+ shift, scale = weight.chunk(2, dim=0)
2102
+ new_weight = torch.cat([scale, shift], dim=0)
2103
+ return new_weight
2104
+
2105
+ ## time_text_embed.timestep_embedder <- time_in
2106
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
2107
+ "time_in.in_layer.weight"
2108
+ )
2109
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
2110
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
2111
+ "time_in.out_layer.weight"
2112
+ )
2113
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
2114
+
2115
+ ## time_text_embed.text_embedder <- vector_in
2116
+ converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
2117
+ converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
2118
+ converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
2119
+ "vector_in.out_layer.weight"
2120
+ )
2121
+ converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
2122
+
2123
+ # guidance
2124
+ has_guidance = any("guidance" in k for k in checkpoint)
2125
+ if has_guidance:
2126
+ converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
2127
+ "guidance_in.in_layer.weight"
2128
+ )
2129
+ converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
2130
+ "guidance_in.in_layer.bias"
2131
+ )
2132
+ converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
2133
+ "guidance_in.out_layer.weight"
2134
+ )
2135
+ converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
2136
+ "guidance_in.out_layer.bias"
2137
+ )
2138
+
2139
+ # context_embedder
2140
+ converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
2141
+ converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
2142
+
2143
+ # x_embedder
2144
+ converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
2145
+ converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
2146
+
2147
+ # double transformer blocks
2148
+ for i in range(num_layers):
2149
+ block_prefix = f"transformer_blocks.{i}."
2150
+ # norms.
2151
+ ## norm1
2152
+ converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
2153
+ f"double_blocks.{i}.img_mod.lin.weight"
2154
+ )
2155
+ converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
2156
+ f"double_blocks.{i}.img_mod.lin.bias"
2157
+ )
2158
+ ## norm1_context
2159
+ converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
2160
+ f"double_blocks.{i}.txt_mod.lin.weight"
2161
+ )
2162
+ converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
2163
+ f"double_blocks.{i}.txt_mod.lin.bias"
2164
+ )
2165
+ # Q, K, V
2166
+ sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
2167
+ context_q, context_k, context_v = torch.chunk(
2168
+ checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
2169
+ )
2170
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
2171
+ checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
2172
+ )
2173
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
2174
+ checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
2175
+ )
2176
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
2177
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
2178
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
2179
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
2180
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
2181
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
2182
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
2183
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
2184
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
2185
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
2186
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
2187
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
2188
+ # qk_norm
2189
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
2190
+ f"double_blocks.{i}.img_attn.norm.query_norm.scale"
2191
+ )
2192
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
2193
+ f"double_blocks.{i}.img_attn.norm.key_norm.scale"
2194
+ )
2195
+ converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
2196
+ f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
2197
+ )
2198
+ converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
2199
+ f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
2200
+ )
2201
+ # ff img_mlp
2202
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
2203
+ f"double_blocks.{i}.img_mlp.0.weight"
2204
+ )
2205
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
2206
+ converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
2207
+ converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
2208
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
2209
+ f"double_blocks.{i}.txt_mlp.0.weight"
2210
+ )
2211
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
2212
+ f"double_blocks.{i}.txt_mlp.0.bias"
2213
+ )
2214
+ converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
2215
+ f"double_blocks.{i}.txt_mlp.2.weight"
2216
+ )
2217
+ converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
2218
+ f"double_blocks.{i}.txt_mlp.2.bias"
2219
+ )
2220
+ # output projections.
2221
+ converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
2222
+ f"double_blocks.{i}.img_attn.proj.weight"
2223
+ )
2224
+ converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
2225
+ f"double_blocks.{i}.img_attn.proj.bias"
2226
+ )
2227
+ converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
2228
+ f"double_blocks.{i}.txt_attn.proj.weight"
2229
+ )
2230
+ converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
2231
+ f"double_blocks.{i}.txt_attn.proj.bias"
2232
+ )
2233
+
2234
+ # single transformer blocks
2235
+ for i in range(num_single_layers):
2236
+ block_prefix = f"single_transformer_blocks.{i}."
2237
+ # norm.linear <- single_blocks.0.modulation.lin
2238
+ converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
2239
+ f"single_blocks.{i}.modulation.lin.weight"
2240
+ )
2241
+ converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
2242
+ f"single_blocks.{i}.modulation.lin.bias"
2243
+ )
2244
+ # Q, K, V, mlp
2245
+ mlp_hidden_dim = int(inner_dim * mlp_ratio)
2246
+ split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
2247
+ q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
2248
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(
2249
+ checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
2250
+ )
2251
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
2252
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
2253
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
2254
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
2255
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
2256
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
2257
+ converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
2258
+ converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
2259
+ # qk norm
2260
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
2261
+ f"single_blocks.{i}.norm.query_norm.scale"
2262
+ )
2263
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
2264
+ f"single_blocks.{i}.norm.key_norm.scale"
2265
+ )
2266
+ # output projections.
2267
+ converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
2268
+ converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
2269
+
2270
+ converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
2271
+ converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
2272
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
2273
+ checkpoint.pop("final_layer.adaLN_modulation.1.weight")
2274
+ )
2275
+ converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
2276
+ checkpoint.pop("final_layer.adaLN_modulation.1.bias")
2277
+ )
2278
+
2279
+ return converted_state_dict
2280
+
2281
+
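A rough usage sketch for the Flux converter above; the path is a hypothetical local file, and ComfyUI-style "model.diffusion_model." prefixes are stripped by the converter itself.

from safetensors.torch import load_file

checkpoint = load_file("flux1-dev.safetensors")  # hypothetical single-file Flux checkpoint
converted = convert_flux_transformer_checkpoint_to_diffusers(checkpoint)

# double_blocks.* become transformer_blocks.*, single_blocks.* become single_transformer_blocks.*
print(any(k.startswith("single_transformer_blocks.") for k in converted))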
2282
+ def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
2283
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key}
2284
+
2285
+ TRANSFORMER_KEYS_RENAME_DICT = {
2286
+ "model.diffusion_model.": "",
2287
+ "patchify_proj": "proj_in",
2288
+ "adaln_single": "time_embed",
2289
+ "q_norm": "norm_q",
2290
+ "k_norm": "norm_k",
2291
+ }
2292
+
2293
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {}
2294
+
2295
+ for key in list(converted_state_dict.keys()):
2296
+ new_key = key
2297
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
2298
+ new_key = new_key.replace(replace_key, rename_key)
2299
+ converted_state_dict[new_key] = converted_state_dict.pop(key)
2300
+
2301
+ for key in list(converted_state_dict.keys()):
2302
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
2303
+ if special_key not in key:
2304
+ continue
2305
+ handler_fn_inplace(key, converted_state_dict)
2306
+
2307
+ return converted_state_dict
2308
+
2309
+
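The LTX converter above (and the DC-AE and HunyuanVideo converters below) follow the same two-pass pattern: substring renames through a rename dict, then targeted handlers for keys that need reshaping rather than renaming. A toy illustration of the first pass:

state_dict = {"model.diffusion_model.patchify_proj.weight": 0, "adaln_single.emb.weight": 1}
rename_dict = {"model.diffusion_model.": "", "patchify_proj": "proj_in", "adaln_single": "time_embed"}
for key in list(state_dict.keys()):
    new_key = key
    for old, new in rename_dict.items():
        new_key = new_key.replace(old, new)
    state_dict[new_key] = state_dict.pop(key)
print(state_dict)  # {'proj_in.weight': 0, 'time_embed.emb.weight': 1}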
2310
+ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
2311
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." in key}
2312
+
2313
+ def remove_keys_(key: str, state_dict):
2314
+ state_dict.pop(key)
2315
+
2316
+ VAE_KEYS_RENAME_DICT = {
2317
+ # common
2318
+ "vae.": "",
2319
+ # decoder
2320
+ "up_blocks.0": "mid_block",
2321
+ "up_blocks.1": "up_blocks.0",
2322
+ "up_blocks.2": "up_blocks.1.upsamplers.0",
2323
+ "up_blocks.3": "up_blocks.1",
2324
+ "up_blocks.4": "up_blocks.2.conv_in",
2325
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
2326
+ "up_blocks.6": "up_blocks.2",
2327
+ "up_blocks.7": "up_blocks.3.conv_in",
2328
+ "up_blocks.8": "up_blocks.3.upsamplers.0",
2329
+ "up_blocks.9": "up_blocks.3",
2330
+ # encoder
2331
+ "down_blocks.0": "down_blocks.0",
2332
+ "down_blocks.1": "down_blocks.0.downsamplers.0",
2333
+ "down_blocks.2": "down_blocks.0.conv_out",
2334
+ "down_blocks.3": "down_blocks.1",
2335
+ "down_blocks.4": "down_blocks.1.downsamplers.0",
2336
+ "down_blocks.5": "down_blocks.1.conv_out",
2337
+ "down_blocks.6": "down_blocks.2",
2338
+ "down_blocks.7": "down_blocks.2.downsamplers.0",
2339
+ "down_blocks.8": "down_blocks.3",
2340
+ "down_blocks.9": "mid_block",
2341
+ # common
2342
+ "conv_shortcut": "conv_shortcut.conv",
2343
+ "res_blocks": "resnets",
2344
+ "norm3.norm": "norm3",
2345
+ "per_channel_statistics.mean-of-means": "latents_mean",
2346
+ "per_channel_statistics.std-of-means": "latents_std",
2347
+ }
2348
+
2349
+ VAE_091_RENAME_DICT = {
2350
+ # decoder
2351
+ "up_blocks.0": "mid_block",
2352
+ "up_blocks.1": "up_blocks.0.upsamplers.0",
2353
+ "up_blocks.2": "up_blocks.0",
2354
+ "up_blocks.3": "up_blocks.1.upsamplers.0",
2355
+ "up_blocks.4": "up_blocks.1",
2356
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
2357
+ "up_blocks.6": "up_blocks.2",
2358
+ "up_blocks.7": "up_blocks.3.upsamplers.0",
2359
+ "up_blocks.8": "up_blocks.3",
2360
+ # common
2361
+ "last_time_embedder": "time_embedder",
2362
+ "last_scale_shift_table": "scale_shift_table",
2363
+ }
2364
+
2365
+ VAE_SPECIAL_KEYS_REMAP = {
2366
+ "per_channel_statistics.channel": remove_keys_,
2367
+ "per_channel_statistics.mean-of-means": remove_keys_,
2368
+ "per_channel_statistics.mean-of-stds": remove_keys_,
2369
+ "timestep_scale_multiplier": remove_keys_,
2370
+ }
2371
+
2372
+ if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
2373
+ VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
2374
+
2375
+ for key in list(converted_state_dict.keys()):
2376
+ new_key = key
2377
+ for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
2378
+ new_key = new_key.replace(replace_key, rename_key)
2379
+ converted_state_dict[new_key] = converted_state_dict.pop(key)
2380
+
2381
+ for key in list(converted_state_dict.keys()):
2382
+ for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
2383
+ if special_key not in key:
2384
+ continue
2385
+ handler_fn_inplace(key, converted_state_dict)
2386
+
2387
+ return converted_state_dict
2388
+
2389
+
2390
+ def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
2391
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
2392
+
2393
+ def remap_qkv_(key: str, state_dict):
2394
+ qkv = state_dict.pop(key)
2395
+ q, k, v = torch.chunk(qkv, 3, dim=0)
2396
+ parent_module, _, _ = key.rpartition(".qkv.conv.weight")
2397
+ state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
2398
+ state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
2399
+ state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
2400
+
2401
+ def remap_proj_conv_(key: str, state_dict):
2402
+ parent_module, _, _ = key.rpartition(".proj.conv.weight")
2403
+ state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
2404
+
2405
+ AE_KEYS_RENAME_DICT = {
2406
+ # common
2407
+ "main.": "",
2408
+ "op_list.": "",
2409
+ "context_module": "attn",
2410
+ "local_module": "conv_out",
2411
+ # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
2412
+ # If there were more scales, there would be more layers, so a loop would be better to handle this
2413
+ "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
2414
+ "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
2415
+ "depth_conv.conv": "conv_depth",
2416
+ "inverted_conv.conv": "conv_inverted",
2417
+ "point_conv.conv": "conv_point",
2418
+ "point_conv.norm": "norm",
2419
+ "conv.conv.": "conv.",
2420
+ "conv1.conv": "conv1",
2421
+ "conv2.conv": "conv2",
2422
+ "conv2.norm": "norm",
2423
+ "proj.norm": "norm_out",
2424
+ # encoder
2425
+ "encoder.project_in.conv": "encoder.conv_in",
2426
+ "encoder.project_out.0.conv": "encoder.conv_out",
2427
+ "encoder.stages": "encoder.down_blocks",
2428
+ # decoder
2429
+ "decoder.project_in.conv": "decoder.conv_in",
2430
+ "decoder.project_out.0": "decoder.norm_out",
2431
+ "decoder.project_out.2.conv": "decoder.conv_out",
2432
+ "decoder.stages": "decoder.up_blocks",
2433
+ }
2434
+
2435
+ AE_F32C32_F64C128_F128C512_KEYS = {
2436
+ "encoder.project_in.conv": "encoder.conv_in.conv",
2437
+ "decoder.project_out.2.conv": "decoder.conv_out.conv",
2438
+ }
2439
+
2440
+ AE_SPECIAL_KEYS_REMAP = {
2441
+ "qkv.conv.weight": remap_qkv_,
2442
+ "proj.conv.weight": remap_proj_conv_,
2443
+ }
2444
+ if "encoder.project_in.conv.bias" not in converted_state_dict:
2445
+ AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS)
2446
+
2447
+ for key in list(converted_state_dict.keys()):
2448
+ new_key = key[:]
2449
+ for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
2450
+ new_key = new_key.replace(replace_key, rename_key)
2451
+ converted_state_dict[new_key] = converted_state_dict.pop(key)
2452
+
2453
+ for key in list(converted_state_dict.keys()):
2454
+ for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
2455
+ if special_key not in key:
2456
+ continue
2457
+ handler_fn_inplace(key, converted_state_dict)
2458
+
2459
+ return converted_state_dict
2460
+
2461
+
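A toy illustration of the fused-QKV split performed by remap_qkv_ above: a (3*C, C, 1, 1) conv weight becomes three squeezed (C, C) projection weights. The shapes here are toy values.

import torch

qkv = torch.randn(3 * 8, 8, 1, 1)
q, k, v = torch.chunk(qkv, 3, dim=0)
print(q.squeeze().shape, k.squeeze().shape, v.squeeze().shape)  # torch.Size([8, 8]) three times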
2462
+ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
2463
+ new_state_dict = {}
2464
+
2465
+ # Comfy checkpoints add this prefix
2466
+ keys = list(checkpoint.keys())
2467
+ for k in keys:
2468
+ if "model.diffusion_model." in k:
2469
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
2470
+
2471
+ # Convert patch_embed
2472
+ new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
2473
+ new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
2474
+
2475
+ # Convert time_embed
2476
+ new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
2477
+ new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
2478
+ new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
2479
+ new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
2480
+ new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
2481
+ new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
2482
+ new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
2483
+ new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
2484
+ new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
2485
+ new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
2486
+ new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
2487
+ new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
2488
+
2489
+ # Convert transformer blocks
2490
+ num_layers = 48
2491
+ for i in range(num_layers):
2492
+ block_prefix = f"transformer_blocks.{i}."
2493
+ old_prefix = f"blocks.{i}."
2494
+
2495
+ # norm1
2496
+ new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
2497
+ new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
2498
+ if i < num_layers - 1:
2499
+ new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
2500
+ new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
2501
+ else:
2502
+ new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
2503
+ old_prefix + "mod_y.weight"
2504
+ )
2505
+ new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
2506
+
2507
+ # Visual attention
2508
+ qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
2509
+ q, k, v = qkv_weight.chunk(3, dim=0)
2510
+
2511
+ new_state_dict[block_prefix + "attn1.to_q.weight"] = q
2512
+ new_state_dict[block_prefix + "attn1.to_k.weight"] = k
2513
+ new_state_dict[block_prefix + "attn1.to_v.weight"] = v
2514
+ new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
2515
+ new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
2516
+ new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
2517
+ new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
2518
+
2519
+ # Context attention
2520
+ qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
2521
+ q, k, v = qkv_weight.chunk(3, dim=0)
2522
+
2523
+ new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
2524
+ new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
2525
+ new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
2526
+ new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
2527
+ old_prefix + "attn.q_norm_y.weight"
2528
+ )
2529
+ new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
2530
+ old_prefix + "attn.k_norm_y.weight"
2531
+ )
2532
+ if i < num_layers - 1:
2533
+ new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
2534
+ old_prefix + "attn.proj_y.weight"
2535
+ )
2536
+ new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
2537
+
2538
+ # MLP
2539
+ new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
2540
+ checkpoint.pop(old_prefix + "mlp_x.w1.weight")
2541
+ )
2542
+ new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
2543
+ if i < num_layers - 1:
2544
+ new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
2545
+ checkpoint.pop(old_prefix + "mlp_y.w1.weight")
2546
+ )
2547
+ new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
2548
+
2549
+ # Output layers
2550
+ new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
2551
+ new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
2552
+ new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
2553
+ new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
2554
+
2555
+ new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
2556
+
2557
+ return new_state_dict
2558
+
2559
+
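A hedged usage sketch for the Mochi converter above; the path is a hypothetical local file, and the converter assumes the full 48-block model since `num_layers` is hardcoded.

from safetensors.torch import load_file

checkpoint = load_file("mochi-1-preview-dit.safetensors")  # hypothetical path
converted = convert_mochi_transformer_checkpoint_to_diffusers(checkpoint)

# blocks.i.* keys end up under transformer_blocks.i.* in diffusers naming.
print(any(k.startswith("transformer_blocks.47.") for k in converted))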
2560
+ def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
2561
+ def remap_norm_scale_shift_(key, state_dict):
2562
+ weight = state_dict.pop(key)
2563
+ shift, scale = weight.chunk(2, dim=0)
2564
+ new_weight = torch.cat([scale, shift], dim=0)
2565
+ state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
2566
+
2567
+ def remap_txt_in_(key, state_dict):
2568
+ def rename_key(key):
2569
+ new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
2570
+ new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
2571
+ new_key = new_key.replace("txt_in", "context_embedder")
2572
+ new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
2573
+ new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
2574
+ new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
2575
+ new_key = new_key.replace("mlp", "ff")
2576
+ return new_key
2577
+
2578
+ if "self_attn_qkv" in key:
2579
+ weight = state_dict.pop(key)
2580
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
2581
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
2582
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
2583
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
2584
+ else:
2585
+ state_dict[rename_key(key)] = state_dict.pop(key)
2586
+
2587
+ def remap_img_attn_qkv_(key, state_dict):
2588
+ weight = state_dict.pop(key)
2589
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
2590
+ state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
2591
+ state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
2592
+ state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
2593
+
2594
+ def remap_txt_attn_qkv_(key, state_dict):
2595
+ weight = state_dict.pop(key)
2596
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
2597
+ state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
2598
+ state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
2599
+ state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
2600
+
2601
+ def remap_single_transformer_blocks_(key, state_dict):
2602
+ hidden_size = 3072
2603
+
2604
+ if "linear1.weight" in key:
2605
+ linear1_weight = state_dict.pop(key)
2606
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
2607
+ q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
2608
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
2609
+ state_dict[f"{new_key}.attn.to_q.weight"] = q
2610
+ state_dict[f"{new_key}.attn.to_k.weight"] = k
2611
+ state_dict[f"{new_key}.attn.to_v.weight"] = v
2612
+ state_dict[f"{new_key}.proj_mlp.weight"] = mlp
2613
+
2614
+ elif "linear1.bias" in key:
2615
+ linear1_bias = state_dict.pop(key)
2616
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
2617
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
2618
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
2619
+ state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
2620
+ state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
2621
+ state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
2622
+ state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
2623
+
2624
+ else:
2625
+ new_key = key.replace("single_blocks", "single_transformer_blocks")
2626
+ new_key = new_key.replace("linear2", "proj_out")
2627
+ new_key = new_key.replace("q_norm", "attn.norm_q")
2628
+ new_key = new_key.replace("k_norm", "attn.norm_k")
2629
+ state_dict[new_key] = state_dict.pop(key)
2630
+
2631
+ TRANSFORMER_KEYS_RENAME_DICT = {
2632
+ "img_in": "x_embedder",
2633
+ "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
2634
+ "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
2635
+ "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
2636
+ "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
2637
+ "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
2638
+ "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
2639
+ "double_blocks": "transformer_blocks",
2640
+ "img_attn_q_norm": "attn.norm_q",
2641
+ "img_attn_k_norm": "attn.norm_k",
2642
+ "img_attn_proj": "attn.to_out.0",
2643
+ "txt_attn_q_norm": "attn.norm_added_q",
2644
+ "txt_attn_k_norm": "attn.norm_added_k",
2645
+ "txt_attn_proj": "attn.to_add_out",
2646
+ "img_mod.linear": "norm1.linear",
2647
+ "img_norm1": "norm1.norm",
2648
+ "img_norm2": "norm2",
2649
+ "img_mlp": "ff",
2650
+ "txt_mod.linear": "norm1_context.linear",
2651
+ "txt_norm1": "norm1.norm",
2652
+ "txt_norm2": "norm2_context",
2653
+ "txt_mlp": "ff_context",
2654
+ "self_attn_proj": "attn.to_out.0",
2655
+ "modulation.linear": "norm.linear",
2656
+ "pre_norm": "norm.norm",
2657
+ "final_layer.norm_final": "norm_out.norm",
2658
+ "final_layer.linear": "proj_out",
2659
+ "fc1": "net.0.proj",
2660
+ "fc2": "net.2",
2661
+ "input_embedder": "proj_in",
2662
+ }
2663
+
2664
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {
2665
+ "txt_in": remap_txt_in_,
2666
+ "img_attn_qkv": remap_img_attn_qkv_,
2667
+ "txt_attn_qkv": remap_txt_attn_qkv_,
2668
+ "single_blocks": remap_single_transformer_blocks_,
2669
+ "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
2670
+ }
2671
+
2672
+ def update_state_dict_(state_dict, old_key, new_key):
2673
+ state_dict[new_key] = state_dict.pop(old_key)
2674
+
2675
+ for key in list(checkpoint.keys()):
2676
+ new_key = key[:]
2677
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
2678
+ new_key = new_key.replace(replace_key, rename_key)
2679
+ update_state_dict_(checkpoint, key, new_key)
2680
+
2681
+ for key in list(checkpoint.keys()):
2682
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
2683
+ if special_key not in key:
2684
+ continue
2685
+ handler_fn_inplace(key, checkpoint)
2686
+
2687
+ return checkpoint
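A hedged usage sketch for the HunyuanVideo converter above; the path is a hypothetical local file. Note that the conversion mutates the input checkpoint in place and returns it in diffusers naming.

from safetensors.torch import load_file

checkpoint = load_file("hunyuan_video_transformer_bf16.safetensors")  # hypothetical path
converted = convert_hunyuan_video_transformer_to_diffusers(checkpoint)

# txt_in.* refiner keys, fused double/single block QKVs and the final adaLN are all remapped,
# e.g. the token refiner ends up under context_embedder.token_refiner.*
print(any(k.startswith("context_embedder.token_refiner.") for k in converted))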