diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -46,6 +46,7 @@ from ..utils import (
46
46
  )
47
47
  from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
48
48
  from ..utils.hub_utils import _get_model_file
49
+ from ..utils.torch_utils import empty_device_cache
49
50
 
50
51
 
51
52
  if is_transformers_available():
@@ -54,11 +55,12 @@ if is_transformers_available():
54
55
  if is_accelerate_available():
55
56
  from accelerate import init_empty_weights
56
57
 
57
- from ..models.modeling_utils import load_model_dict_into_meta
58
+ from ..models.model_loading_utils import load_model_dict_into_meta
58
59
 
59
60
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
60
61
 
61
62
  CHECKPOINT_KEY_NAMES = {
63
+ "v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
62
64
  "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
63
65
  "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
64
66
  "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
@@ -126,6 +128,18 @@ CHECKPOINT_KEY_NAMES = {
126
128
  ],
127
129
  "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
128
130
  "wan_vae": "decoder.middle.0.residual.0.gamma",
131
+ "wan_vace": "vace_blocks.0.after_proj.bias",
132
+ "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
133
+ "cosmos-1.0": [
134
+ "net.x_embedder.proj.1.weight",
135
+ "net.blocks.block1.blocks.0.block.attn.to_q.0.weight",
136
+ "net.extra_pos_embedder.pos_emb_h",
137
+ ],
138
+ "cosmos-2.0": [
139
+ "net.x_embedder.proj.1.weight",
140
+ "net.blocks.0.self_attn.q_proj.weight",
141
+ "net.pos_embedder.dim_spatial_range",
142
+ ],
129
143
  }
130
144
 
131
145
  DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -177,6 +191,8 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
177
191
  "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
178
192
  "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
179
193
  "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
194
+ "ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
195
+ "ltx-video-0.9.7": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.7-dev"},
180
196
  "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
181
197
  "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
182
198
  "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
@@ -189,6 +205,17 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
189
205
  "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
190
206
  "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
191
207
  "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
208
+ "wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
209
+ "wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
210
+ "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
211
+ "cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
212
+ "cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
213
+ "cosmos-1.0-v2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"},
214
+ "cosmos-1.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Video2World"},
215
+ "cosmos-2.0-t2i-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Text2Image"},
216
+ "cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
217
+ "cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
218
+ "cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
192
219
  }
193
220
 
194
221
  # Use to configure model sample size when original config is provided
@@ -404,13 +431,16 @@ def load_single_file_checkpoint(
404
431
  local_files_only=None,
405
432
  revision=None,
406
433
  disable_mmap=False,
434
+ user_agent=None,
407
435
  ):
436
+ if user_agent is None:
437
+ user_agent = {"file_type": "single_file", "framework": "pytorch"}
438
+
408
439
  if os.path.isfile(pretrained_model_link_or_path):
409
440
  pretrained_model_link_or_path = pretrained_model_link_or_path
410
441
 
411
442
  else:
412
443
  repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
413
- user_agent = {"file_type": "single_file", "framework": "pytorch"}
414
444
  pretrained_model_link_or_path = _get_model_file(
415
445
  repo_id,
416
446
  weights_name=weights_name,
@@ -638,7 +668,12 @@ def infer_diffusers_model_type(checkpoint):
638
668
  model_type = "flux-schnell"
639
669
 
640
670
  elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
641
- if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
671
+ has_vae = "vae.encoder.conv_in.conv.bias" in checkpoint
672
+ if any(key.endswith("transformer_blocks.47.scale_shift_table") for key in checkpoint):
673
+ model_type = "ltx-video-0.9.7"
674
+ elif has_vae and checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
675
+ model_type = "ltx-video-0.9.5"
676
+ elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
642
677
  model_type = "ltx-video-0.9.1"
643
678
  else:
644
679
  model_type = "ltx-video"
@@ -686,15 +721,44 @@ def infer_diffusers_model_type(checkpoint):
686
721
  else:
687
722
  target_key = "patch_embedding.weight"
688
723
 
689
- if checkpoint[target_key].shape[0] == 1536:
724
+ if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint:
725
+ if checkpoint[target_key].shape[0] == 1536:
726
+ model_type = "wan-vace-1.3B"
727
+ elif checkpoint[target_key].shape[0] == 5120:
728
+ model_type = "wan-vace-14B"
729
+
730
+ elif checkpoint[target_key].shape[0] == 1536:
690
731
  model_type = "wan-t2v-1.3B"
691
732
  elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
692
733
  model_type = "wan-t2v-14B"
693
734
  else:
694
735
  model_type = "wan-i2v-14B"
736
+
695
737
  elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
696
738
  # All Wan models use the same VAE so we can use the same default model repo to fetch the config
697
739
  model_type = "wan-t2v-14B"
740
+
741
+ elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
742
+ model_type = "hidream"
743
+
744
+ elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-1.0"]):
745
+ x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-1.0"][0]].shape
746
+ if x_embedder_shape[1] == 68:
747
+ model_type = "cosmos-1.0-t2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-t2w-14B"
748
+ elif x_embedder_shape[1] == 72:
749
+ model_type = "cosmos-1.0-v2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-v2w-14B"
750
+ else:
751
+ raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 1.0 model.")
752
+
753
+ elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-2.0"]):
754
+ x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-2.0"][0]].shape
755
+ if x_embedder_shape[1] == 68:
756
+ model_type = "cosmos-2.0-t2i-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-t2i-14B"
757
+ elif x_embedder_shape[1] == 72:
758
+ model_type = "cosmos-2.0-v2w-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-v2w-14B"
759
+ else:
760
+ raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
761
+
698
762
  else:
699
763
  model_type = "v1"
700
764
 
@@ -1627,6 +1691,7 @@ def create_diffusers_clip_model_from_ldm(
1627
1691
 
1628
1692
  if is_accelerate_available():
1629
1693
  load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
1694
+ empty_device_cache()
1630
1695
  else:
1631
1696
  model.load_state_dict(diffusers_format_checkpoint, strict=False)
1632
1697
 
@@ -2086,6 +2151,7 @@ def create_diffusers_t5_model_from_checkpoint(
2086
2151
 
2087
2152
  if is_accelerate_available():
2088
2153
  load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
2154
+ empty_device_cache()
2089
2155
  else:
2090
2156
  model.load_state_dict(diffusers_format_checkpoint)
2091
2157
 
@@ -2272,7 +2338,7 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
2272
2338
  f"double_blocks.{i}.txt_attn.proj.bias"
2273
2339
  )
2274
2340
 
2275
- # single transfomer blocks
2341
+ # single transformer blocks
2276
2342
  for i in range(num_single_layers):
2277
2343
  block_prefix = f"single_transformer_blocks.{i}."
2278
2344
  # norm.linear <- single_blocks.0.modulation.lin
@@ -2403,13 +2469,41 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
2403
2469
  "last_scale_shift_table": "scale_shift_table",
2404
2470
  }
2405
2471
 
2472
+ VAE_095_RENAME_DICT = {
2473
+ # decoder
2474
+ "up_blocks.0": "mid_block",
2475
+ "up_blocks.1": "up_blocks.0.upsamplers.0",
2476
+ "up_blocks.2": "up_blocks.0",
2477
+ "up_blocks.3": "up_blocks.1.upsamplers.0",
2478
+ "up_blocks.4": "up_blocks.1",
2479
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
2480
+ "up_blocks.6": "up_blocks.2",
2481
+ "up_blocks.7": "up_blocks.3.upsamplers.0",
2482
+ "up_blocks.8": "up_blocks.3",
2483
+ # encoder
2484
+ "down_blocks.0": "down_blocks.0",
2485
+ "down_blocks.1": "down_blocks.0.downsamplers.0",
2486
+ "down_blocks.2": "down_blocks.1",
2487
+ "down_blocks.3": "down_blocks.1.downsamplers.0",
2488
+ "down_blocks.4": "down_blocks.2",
2489
+ "down_blocks.5": "down_blocks.2.downsamplers.0",
2490
+ "down_blocks.6": "down_blocks.3",
2491
+ "down_blocks.7": "down_blocks.3.downsamplers.0",
2492
+ "down_blocks.8": "mid_block",
2493
+ # common
2494
+ "last_time_embedder": "time_embedder",
2495
+ "last_scale_shift_table": "scale_shift_table",
2496
+ }
2497
+
2406
2498
  VAE_SPECIAL_KEYS_REMAP = {
2407
2499
  "per_channel_statistics.channel": remove_keys_,
2408
2500
  "per_channel_statistics.mean-of-means": remove_keys_,
2409
2501
  "per_channel_statistics.mean-of-stds": remove_keys_,
2410
2502
  }
2411
2503
 
2412
- if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
2504
+ if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
2505
+ VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
2506
+ elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
2413
2507
  VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
2414
2508
 
2415
2509
  for key in list(converted_state_dict.keys()):
@@ -2838,7 +2932,7 @@ def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
2838
2932
  def convert_lumina2_to_diffusers(checkpoint, **kwargs):
2839
2933
  converted_state_dict = {}
2840
2934
 
2841
- # Original Lumina-Image-2 has an extra norm paramter that is unused
2935
+ # Original Lumina-Image-2 has an extra norm parameter that is unused
2842
2936
  # We just remove it here
2843
2937
  checkpoint.pop("norm_final.weight", None)
2844
2938
 
@@ -3051,6 +3145,9 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
3051
3145
  "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
3052
3146
  "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
3053
3147
  "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
3148
+ # For the VACE model
3149
+ "before_proj": "proj_in",
3150
+ "after_proj": "proj_out",
3054
3151
  }
3055
3152
 
3056
3153
  for key in list(checkpoint.keys()):
@@ -3259,3 +3356,294 @@ def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
3259
3356
  converted_state_dict[key] = value
3260
3357
 
3261
3358
  return converted_state_dict
3359
+
3360
+
3361
def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
    """Strip the ``model.diffusion_model.`` wrapper from HiDream checkpoint keys.

    The mapping is edited in place: every key containing the wrapper string is
    popped and re-inserted under the name with that string removed. Keys that do
    not contain it are left untouched.

    Args:
        checkpoint: Mutable mapping of parameter names to values.
        **kwargs: Ignored by this function.

    Returns:
        The same ``checkpoint`` object that was passed in.
    """
    wrapper = "model.diffusion_model."

    # Iterate over a snapshot because the mapping is modified inside the loop.
    for name in tuple(checkpoint):
        if wrapper not in name:
            continue
        value = checkpoint.pop(name)
        checkpoint[name.replace(wrapper, "")] = value

    return checkpoint
3368
+
3369
+
3370
def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Convert an original-layout Chroma transformer state dict to diffusers key names.

    Keys carrying a ``model.diffusion_model.`` wrapper are unwrapped first. Fused
    ``qkv`` / ``linear1`` parameters are split along dim 0 into separate
    projections; all other handled parameters are moved under their new names.
    Every consumed entry is popped from ``checkpoint``, so the input mapping is
    mutated.

    Args:
        checkpoint: Mutable mapping from parameter name to tensor.
        **kwargs: Ignored by this function.

    Returns:
        A new dict containing the converted parameters.

    Raises:
        KeyError: If a parameter expected by the conversion is missing.
    """
    converted_state_dict = {}

    # Snapshot the keys: the mapping is modified while unwrapping.
    for k in list(checkpoint.keys()):
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    def count_blocks(marker, index_position):
        # Number of blocks = highest block index + 1. `max` is used instead of
        # taking the last element of a set, whose iteration order is unspecified.
        indices = (int(k.split(".")[index_position]) for k in checkpoint if marker in k)
        return max(indices, default=-1) + 1

    num_layers = count_blocks("double_blocks.", 1)
    num_single_layers = count_blocks("single_blocks.", 1)
    num_guidance_layers = count_blocks("distilled_guidance_layer.layers.", 2)
    mlp_ratio = 4.0
    inner_dim = 3072

    # guidance
    for name in ("in_proj.bias", "in_proj.weight", "out_proj.bias", "out_proj.weight"):
        key = f"distilled_guidance_layer.{name}"
        converted_state_dict[key] = checkpoint.pop(key)

    guidance_renames = (
        ("linear_1.bias", "in_layer.bias"),
        ("linear_1.weight", "in_layer.weight"),
        ("linear_2.bias", "out_layer.bias"),
        ("linear_2.weight", "out_layer.weight"),
    )
    for i in range(num_guidance_layers):
        block_prefix = f"distilled_guidance_layer.layers.{i}."
        for new_suffix, old_suffix in guidance_renames:
            converted_state_dict[f"{block_prefix}{new_suffix}"] = checkpoint.pop(f"{block_prefix}{old_suffix}")
        converted_state_dict[f"distilled_guidance_layer.norms.{i}.weight"] = checkpoint.pop(
            f"distilled_guidance_layer.norms.{i}.scale"
        )

    # context_embedder
    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")

    # x_embedder
    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
    converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")

    # (original stream, diffusers projection names for the q / k / v thirds)
    double_qkv_targets = (
        ("img", ("to_q", "to_k", "to_v")),
        ("txt", ("add_q_proj", "add_k_proj", "add_v_proj")),
    )
    # (diffusers suffix, original suffix) for parameters that are only renamed
    double_renames = (
        # qk_norm
        ("attn.norm_q.weight", "img_attn.norm.query_norm.scale"),
        ("attn.norm_k.weight", "img_attn.norm.key_norm.scale"),
        ("attn.norm_added_q.weight", "txt_attn.norm.query_norm.scale"),
        ("attn.norm_added_k.weight", "txt_attn.norm.key_norm.scale"),
        # ff img_mlp / txt_mlp
        ("ff.net.0.proj.weight", "img_mlp.0.weight"),
        ("ff.net.0.proj.bias", "img_mlp.0.bias"),
        ("ff.net.2.weight", "img_mlp.2.weight"),
        ("ff.net.2.bias", "img_mlp.2.bias"),
        ("ff_context.net.0.proj.weight", "txt_mlp.0.weight"),
        ("ff_context.net.0.proj.bias", "txt_mlp.0.bias"),
        ("ff_context.net.2.weight", "txt_mlp.2.weight"),
        ("ff_context.net.2.bias", "txt_mlp.2.bias"),
        # output projections
        ("attn.to_out.0.weight", "img_attn.proj.weight"),
        ("attn.to_out.0.bias", "img_attn.proj.bias"),
        ("attn.to_add_out.weight", "txt_attn.proj.weight"),
        ("attn.to_add_out.bias", "txt_attn.proj.bias"),
    )

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        # Q, K, V
        for stream, proj_names in double_qkv_targets:
            w_q, w_k, w_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.{stream}_attn.qkv.weight"), 3, dim=0)
            b_q, b_k, b_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.{stream}_attn.qkv.bias"), 3, dim=0)
            for proj, weight, bias in zip(proj_names, (w_q, w_k, w_v), (b_q, b_k, b_v)):
                # Single-tensor `torch.cat` is kept from the original conversion;
                # presumably it stores each slice as its own tensor rather than
                # as a piece of the fused parameter -- TODO confirm.
                converted_state_dict[f"{block_prefix}attn.{proj}.weight"] = torch.cat([weight])
                converted_state_dict[f"{block_prefix}attn.{proj}.bias"] = torch.cat([bias])
        for new_suffix, old_suffix in double_renames:
            converted_state_dict[f"{block_prefix}{new_suffix}"] = checkpoint.pop(f"double_blocks.{i}.{old_suffix}")

    # `linear1` stacks q, k, v and the MLP input projection along dim 0.
    mlp_hidden_dim = int(inner_dim * mlp_ratio)
    split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
    single_targets = ("attn.to_q", "attn.to_k", "attn.to_v", "proj_mlp")

    # single transformer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."
        # Q, K, V, mlp
        q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(
            checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
        )
        for target, weight, bias in zip(single_targets, (q, k, v, mlp), (q_bias, k_bias, v_bias, mlp_bias)):
            converted_state_dict[f"{block_prefix}{target}.weight"] = torch.cat([weight])
            converted_state_dict[f"{block_prefix}{target}.bias"] = torch.cat([bias])
        # qk norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
            f"single_blocks.{i}.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
            f"single_blocks.{i}.norm.key_norm.scale"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
        converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")

    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")

    return converted_state_dict
3537
+
3538
+
3539
def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Convert an original-layout Cosmos transformer state dict to diffusers key names.

    All entries are moved out of ``checkpoint`` (it is left empty), the leading
    ``net.`` prefix is dropped, substring renames from the Cosmos 1.0 or 2.0
    table are applied in table order, and finally the "special" handlers remove
    or re-index the remaining keys.

    The 1.0 layout is recognised by the presence of
    ``net.blocks.block1.blocks.0.block.attn.to_q.0.weight``; anything else is
    treated as 2.0.

    Args:
        checkpoint: Mutable mapping from parameter name to value.
        **kwargs: Ignored by this function.

    Returns:
        A new dict with the renamed parameters.
    """
    # Drain the caller's mapping into a working copy.
    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}

    def remove_keys_(key: str, state_dict):
        # Drop an entry the conversion does not carry over.
        state_dict.pop(key)

    def rename_transformer_blocks_(key: str, state_dict):
        # "blocks.block<N>..." -> "transformer_blocks.<N>..."
        block_index = int(key.split(".")[1].removeprefix("block"))
        old_prefix = f"blocks.block{block_index}"
        new_prefix = f"transformer_blocks.{block_index}"
        new_key = new_prefix + key.removeprefix(old_prefix)
        state_dict[new_key] = state_dict.pop(key)

    # NOTE: the rename tables are applied as sequential substring replacements,
    # so entry order is significant.
    TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
        "t_embedder.1": "time_embed.t_embedder",
        "affline_norm": "time_embed.norm",
        ".blocks.0.block.attn": ".attn1",
        ".blocks.1.block.attn": ".attn2",
        ".blocks.2.block": ".ff",
        ".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
        ".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
        ".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
        ".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
        ".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
        ".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
        "to_q.0": "to_q",
        "to_q.1": "norm_q",
        "to_k.0": "to_k",
        "to_k.1": "norm_k",
        "to_v.0": "to_v",
        "layer1": "net.0.proj",
        "layer2": "net.2",
        "proj.1": "proj",
        "x_embedder": "patch_embed",
        "extra_pos_embedder": "learnable_pos_embed",
        "final_layer.adaLN_modulation.1": "norm_out.linear_1",
        "final_layer.adaLN_modulation.2": "norm_out.linear_2",
        "final_layer.linear": "proj_out",
    }

    TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
        "blocks.block": rename_transformer_blocks_,
        "logvar.0.freqs": remove_keys_,
        "logvar.0.phases": remove_keys_,
        "logvar.1.weight": remove_keys_,
        "pos_embedder.seq": remove_keys_,
    }

    TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
        "t_embedder.1": "time_embed.t_embedder",
        "t_embedding_norm": "time_embed.norm",
        "blocks": "transformer_blocks",
        "adaln_modulation_self_attn.1": "norm1.linear_1",
        "adaln_modulation_self_attn.2": "norm1.linear_2",
        "adaln_modulation_cross_attn.1": "norm2.linear_1",
        "adaln_modulation_cross_attn.2": "norm2.linear_2",
        "adaln_modulation_mlp.1": "norm3.linear_1",
        "adaln_modulation_mlp.2": "norm3.linear_2",
        "self_attn": "attn1",
        "cross_attn": "attn2",
        "q_proj": "to_q",
        "k_proj": "to_k",
        "v_proj": "to_v",
        "output_proj": "to_out.0",
        "q_norm": "norm_q",
        "k_norm": "norm_k",
        "mlp.layer1": "ff.net.0.proj",
        "mlp.layer2": "ff.net.2",
        "x_embedder.proj.1": "patch_embed.proj",
        "final_layer.adaln_modulation.1": "norm_out.linear_1",
        "final_layer.adaln_modulation.2": "norm_out.linear_2",
        "final_layer.linear": "proj_out",
    }

    TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
        "accum_video_sample_counter": remove_keys_,
        "accum_image_sample_counter": remove_keys_,
        "accum_iteration": remove_keys_,
        "accum_train_in_hours": remove_keys_,
        "pos_embedder.seq": remove_keys_,
        "pos_embedder.dim_spatial_range": remove_keys_,
        "pos_embedder.dim_temporal_range": remove_keys_,
        "_extra_state": remove_keys_,
    }

    PREFIX_KEY = "net."
    # The layout check must look at `converted_state_dict`: `checkpoint` was
    # emptied above, so testing it would always select the 2.0 tables.
    if "net.blocks.block1.blocks.0.block.attn.to_q.0.weight" in converted_state_dict:
        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
    else:
        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0

    # Pass 1: strip the prefix and apply the substring renames.
    for key in list(converted_state_dict.keys()):
        new_key = key.removeprefix(PREFIX_KEY)
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        converted_state_dict[new_key] = converted_state_dict.pop(key)

    # Pass 2: run the first matching special handler on each key.
    for key in list(converted_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, converted_state_dict)
            # The handler has removed or renamed `key`; a second handler would
            # fail trying to pop it again.
            break

    return converted_state_dict
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -427,7 +427,8 @@ class TextualInversionLoaderMixin:
427
427
  logger.info(
428
428
  "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
429
429
  )
430
- remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
430
+ if is_sequential_cpu_offload or is_model_cpu_offload:
431
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
431
432
 
432
433
  # 7.2 save expected device and dtype
433
434
  device = text_encoder.device
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -17,12 +17,10 @@ from ..models.embeddings import (
17
17
  ImageProjection,
18
18
  MultiIPAdapterImageProjection,
19
19
  )
20
- from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
21
- from ..utils import (
22
- is_accelerate_available,
23
- is_torch_version,
24
- logging,
25
- )
20
+ from ..models.model_loading_utils import load_model_dict_into_meta
21
+ from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
22
+ from ..utils import is_accelerate_available, is_torch_version, logging
23
+ from ..utils.torch_utils import empty_device_cache
26
24
 
27
25
 
28
26
  if is_accelerate_available():
@@ -84,13 +82,12 @@ class FluxTransformer2DLoadersMixin:
84
82
  else:
85
83
  device_map = {"": self.device}
86
84
  load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
85
+ empty_device_cache()
87
86
 
88
87
  return image_projection
89
88
 
90
89
  def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
91
- from ..models.attention_processor import (
92
- FluxIPAdapterJointAttnProcessor2_0,
93
- )
90
+ from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
94
91
 
95
92
  if low_cpu_mem_usage:
96
93
  if is_accelerate_available():
@@ -122,7 +119,7 @@ class FluxTransformer2DLoadersMixin:
122
119
  else:
123
120
  cross_attention_dim = self.config.joint_attention_dim
124
121
  hidden_size = self.inner_dim
125
- attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
122
+ attn_processor_class = FluxIPAdapterAttnProcessor
126
123
  num_image_text_embeds = []
127
124
  for state_dict in state_dicts:
128
125
  if "proj.weight" in state_dict["image_proj"]:
@@ -158,6 +155,8 @@ class FluxTransformer2DLoadersMixin:
158
155
 
159
156
  key_id += 1
160
157
 
158
+ empty_device_cache()
159
+
161
160
  return attn_procs
162
161
 
163
162
  def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -16,8 +16,10 @@ from typing import Dict
16
16
 
17
17
  from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
18
18
  from ..models.embeddings import IPAdapterTimeImageProjection
19
- from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
19
+ from ..models.model_loading_utils import load_model_dict_into_meta
20
+ from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
20
21
  from ..utils import is_accelerate_available, is_torch_version, logging
22
+ from ..utils.torch_utils import empty_device_cache
21
23
 
22
24
 
23
25
  logger = logging.get_logger(__name__)
@@ -80,6 +82,8 @@ class SD3Transformer2DLoadersMixin:
80
82
  attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
81
83
  )
82
84
 
85
+ empty_device_cache()
86
+
83
87
  return attn_procs
84
88
 
85
89
  def _convert_ip_adapter_image_proj_to_diffusers(
@@ -123,7 +127,7 @@ class SD3Transformer2DLoadersMixin:
123
127
  key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
124
128
  updated_state_dict[key] = value
125
129
 
126
- # Image projetion parameters
130
+ # Image projection parameters
127
131
  embed_dim = updated_state_dict["proj_in.weight"].shape[1]
128
132
  output_dim = updated_state_dict["proj_out.weight"].shape[0]
129
133
  hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
@@ -147,6 +151,7 @@ class SD3Transformer2DLoadersMixin:
147
151
  else:
148
152
  device_map = {"": self.device}
149
153
  load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
154
+ empty_device_cache()
150
155
 
151
156
  return image_proj
152
157