diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +595 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
  384. diffusers-0.33.1.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
@@ -22,17 +22,20 @@ from diffusers.loaders import FromOriginalModelMixin
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import FeedForward
 from ..attention_processor import Attention, AttentionProcessor
+from ..cache_utils import CacheMixin
 from ..embeddings import (
-    CombinedTimestepGuidanceTextProjEmbeddings,
     CombinedTimestepTextProjEmbeddings,
+    PixArtAlphaTextProjection,
+    TimestepEmbedding,
+    Timesteps,
     get_1d_rotary_pos_embed,
 )
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -172,6 +175,141 @@ class HunyuanVideoAdaNorm(nn.Module):
         return gate_msa, gate_mlp
 
 
+class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        elif norm_type == "fp32_layer_norm":
+            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: torch.Tensor,
+        token_replace_emb: torch.Tensor,
+        first_frame_num_tokens: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        emb = self.linear(self.silu(emb))
+        token_replace_emb = self.linear(self.silu(token_replace_emb))
+
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
+        tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk(
+            6, dim=1
+        )
+
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states_zero = (
+            norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
+        )
+        hidden_states_orig = (
+            norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        )
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+
+        return (
+            hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            tr_gate_msa,
+            tr_shift_mlp,
+            tr_scale_mlp,
+            tr_gate_mlp,
+        )
+
+
+class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module):
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: torch.Tensor,
+        token_replace_emb: torch.Tensor,
+        first_frame_num_tokens: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        emb = self.linear(self.silu(emb))
+        token_replace_emb = self.linear(self.silu(token_replace_emb))
+
+        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
+        tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1)
+
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states_zero = (
+            norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
+        )
+        hidden_states_orig = (
+            norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        )
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+
+        return hidden_states, gate_msa, tr_gate_msa
+
+
+class HunyuanVideoConditionEmbedding(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        pooled_projection_dim: int,
+        guidance_embeds: bool,
+        image_condition_type: Optional[str] = None,
+    ):
+        super().__init__()
+
+        self.image_condition_type = image_condition_type
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+        self.guidance_embedder = None
+        if guidance_embeds:
+            self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+    def forward(
+        self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+        pooled_projections = self.text_embedder(pooled_projection)
+        conditioning = timesteps_emb + pooled_projections
+
+        token_replace_emb = None
+        if self.image_condition_type == "token_replace":
+            token_replace_timestep = torch.zeros_like(timestep)
+            token_replace_proj = self.time_proj(token_replace_timestep)
+            token_replace_emb = self.timestep_embedder(token_replace_proj.to(dtype=pooled_projection.dtype))
+            token_replace_emb = token_replace_emb + pooled_projections
+
+        if self.guidance_embedder is not None:
+            guidance_proj = self.time_proj(guidance)
+            guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
+            conditioning = conditioning + guidance_emb
+
+        return conditioning, token_replace_emb
+
+
 class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
     def __init__(
         self,
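The two new `TokenReplace` norm layers above implement the same split: tokens belonging to the first (conditioning) frame are modulated with parameters derived from a zeroed timestep (`token_replace_emb` in `HunyuanVideoConditionEmbedding` is built from `torch.zeros_like(timestep)`), while all later tokens use the ordinary timestep embedding. A minimal standalone sketch of that split, with made-up shapes:

```python
import torch

# Illustrative shapes only; in the model these come from patchified video latents.
batch, seq_len, dim = 2, 10, 8
first_frame_num_tokens = 4

hidden_states = torch.randn(batch, seq_len, dim)
scale, shift = torch.randn(batch, dim), torch.randn(batch, dim)        # from emb (sampled t)
tr_scale, tr_shift = torch.randn(batch, dim), torch.randn(batch, dim)  # from token_replace_emb (t = 0)

# First-frame tokens are modulated with the "clean" (t = 0) parameters ...
zero = hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale[:, None]) + tr_shift[:, None]
# ... all remaining tokens with the regular timestep parameters.
orig = hidden_states[:, first_frame_num_tokens:] * (1 + scale[:, None]) + shift[:, None]

out = torch.cat([zero, orig], dim=1)
assert out.shape == hidden_states.shape
```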
@@ -389,6 +527,8 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
         temb: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -467,6 +607,8 @@ class HunyuanVideoTransformerBlock(nn.Module):
         temb: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        *args,
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Input normalization
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
@@ -502,7 +644,182 @@ class HunyuanVideoTransformerBlock(nn.Module):
         return hidden_states, encoder_hidden_states
 
 
-class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float = 4.0,
+        qk_norm: str = "rms_norm",
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+        mlp_dim = int(hidden_size * mlp_ratio)
+
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=hidden_size,
+            bias=True,
+            processor=HunyuanVideoAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=1e-6,
+            pre_only=True,
+        )
+
+        self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
+        self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
+        self.act_mlp = nn.GELU(approximate="tanh")
+        self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        token_replace_emb: torch.Tensor = None,
+        num_tokens: int = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.shape[1]
+        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+        residual = hidden_states
+
+        # 1. Input normalization
+        norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens)
+        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+        norm_hidden_states, norm_encoder_hidden_states = (
+            norm_hidden_states[:, :-text_seq_length, :],
+            norm_hidden_states[:, -text_seq_length:, :],
+        )
+
+        # 2. Attention
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            attention_mask=attention_mask,
+            image_rotary_emb=image_rotary_emb,
+        )
+        attn_output = torch.cat([attn_output, context_attn_output], dim=1)
+
+        # 3. Modulation and residual connection
+        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+
+        proj_output = self.proj_out(hidden_states)
+        hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1)
+        hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1)
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        hidden_states = hidden_states + residual
+
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, :-text_seq_length, :],
+            hidden_states[:, -text_seq_length:, :],
+        )
+        return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float,
+        qk_norm: str = "rms_norm",
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+
+        self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm")
+        self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            added_kv_proj_dim=hidden_size,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=hidden_size,
+            context_pre_only=False,
+            bias=True,
+            processor=HunyuanVideoAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=1e-6,
+        )
+
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+        self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        token_replace_emb: torch.Tensor = None,
+        num_tokens: int = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # 1. Input normalization
+        (
+            norm_hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            tr_gate_msa,
+            tr_shift_mlp,
+            tr_scale_mlp,
+            tr_gate_mlp,
+        ) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens)
+        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+            encoder_hidden_states, emb=temb
+        )
+
+        # 2. Joint attention
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            attention_mask=attention_mask,
+            image_rotary_emb=freqs_cis,
+        )
+
+        # 3. Modulation and residual connection
+        hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1)
+        hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1)
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+        hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
+        hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        norm_hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+        # 4. Feed-forward
+        ff_output = self.ff(norm_hidden_states)
+        context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+        hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1)
+        hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1)
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+        return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
     r"""
     A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
 
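`HunyuanVideoTransformer3DModel` now also inherits `CacheMixin` (see the new `diffusers/models/cache_utils.py` in the file list), which exposes `enable_cache`/`disable_cache` on the model. A hedged usage sketch; the checkpoint id and config values are illustrative assumptions, not taken from this diff:

```python
import torch
from diffusers import HunyuanVideoPipeline, PyramidAttentionBroadcastConfig

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# Skip ranges here are placeholder values; tune per model/schedule.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)  # provided by the new CacheMixin base
```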
@@ -539,9 +856,20 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
             The value of theta to use in the RoPE layer.
         rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
             The dimensions of the axes to use in the RoPE layer.
+        image_condition_type (`str`, *optional*, defaults to `None`):
+            The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
+            image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
+            tokens in the latent stream and apply conditioning.
     """
 
     _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
+    _no_split_modules = [
+        "HunyuanVideoTransformerBlock",
+        "HunyuanVideoSingleTransformerBlock",
+        "HunyuanVideoPatchEmbed",
+        "HunyuanVideoTokenRefiner",
+    ]
 
     @register_to_config
     def __init__(
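`_skip_layerwise_casting_patterns` keeps the embedders and norm layers out of low-precision storage when layerwise casting is enabled (see the new `diffusers/hooks/layerwise_casting.py` above), and `_no_split_modules` makes the model usable with `device_map`-based sharding. A sketch of the casting API these attributes feed into; the checkpoint id is an assumption:

```python
import torch
from diffusers import HunyuanVideoTransformer3DModel

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Weights are stored in fp8 and upcast to bf16 for compute; modules matching
# _skip_layerwise_casting_patterns ("x_embedder", "context_embedder", "norm")
# are left in full precision to preserve quality.
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```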
@@ -562,9 +890,16 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         pooled_projection_dim: int = 768,
         rope_theta: float = 256.0,
         rope_axes_dim: Tuple[int] = (16, 56, 56),
+        image_condition_type: Optional[str] = None,
     ) -> None:
         super().__init__()
 
+        supported_image_condition_types = ["latent_concat", "token_replace"]
+        if image_condition_type is not None and image_condition_type not in supported_image_condition_types:
+            raise ValueError(
+                f"Invalid `image_condition_type` ({image_condition_type}). Supported ones are: {supported_image_condition_types}"
+            )
+
         inner_dim = num_attention_heads * attention_head_dim
         out_channels = out_channels or in_channels
 
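The constructor now validates `image_condition_type` up front. A tiny illustrative configuration (the small sizes are arbitrary so the sketch builds quickly; the released checkpoints are far larger):

```python
from diffusers import HunyuanVideoTransformer3DModel

model = HunyuanVideoTransformer3DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    num_layers=1,
    num_single_layers=1,
    image_condition_type="token_replace",  # selects the TokenReplace block variants
)
# image_condition_type="crop" would raise:
# ValueError: Invalid `image_condition_type` (crop). Supported ones are: ['latent_concat', 'token_replace']
```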
@@ -573,30 +908,53 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         self.context_embedder = HunyuanVideoTokenRefiner(
             text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
         )
-        self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
+
+        self.time_text_embed = HunyuanVideoConditionEmbedding(
+            inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
+        )
 
         # 2. RoPE
         self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
 
         # 3. Dual stream transformer blocks
-        self.transformer_blocks = nn.ModuleList(
-            [
-                HunyuanVideoTransformerBlock(
-                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
-                )
-                for _ in range(num_layers)
-            ]
-        )
+        if image_condition_type == "token_replace":
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoTokenReplaceTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
+        else:
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
 
         # 4. Single stream transformer blocks
-        self.single_transformer_blocks = nn.ModuleList(
-            [
-                HunyuanVideoSingleTransformerBlock(
-                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
-                )
-                for _ in range(num_single_layers)
-            ]
-        )
+        if image_condition_type == "token_replace":
+            self.single_transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoTokenReplaceSingleTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_single_layers)
+                ]
+            )
+        else:
+            self.single_transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoSingleTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_single_layers)
+                ]
+            )
 
         # 5. Output projection
         self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
@@ -664,10 +1022,6 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
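The per-model `_set_gradient_checkpointing` hook is removed because checkpointing is now handled centrally by `ModelMixin` and dispatched through `self._gradient_checkpointing_func` (used in the forward pass below). The public API is unchanged; continuing the tiny-config sketch above:

```python
from diffusers import HunyuanVideoTransformer3DModel

# Tiny illustrative configuration, as in the earlier sketch.
model = HunyuanVideoTransformer3DModel(
    num_attention_heads=2, attention_head_dim=8, num_layers=1, num_single_layers=1
)
model.enable_gradient_checkpointing()  # still works; no per-model hook needed
```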
@@ -699,12 +1053,14 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         post_patch_num_frames = num_frames // p_t
         post_patch_height = height // p
         post_patch_width = width // p
+        first_frame_num_tokens = 1 * post_patch_height * post_patch_width
 
         # 1. RoPE
         image_rotary_emb = self.rope(hidden_states)
 
         # 2. Conditional embeddings
-        temb = self.time_text_embed(timestep, guidance, pooled_projections)
+        temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance)
+
         hidden_states = self.x_embedder(hidden_states)
         encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
 
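`first_frame_num_tokens` is simply the number of patch tokens in one latent frame, i.e. the tokens whose modulation gets swapped in the token-replace path. Toy numbers:

```python
# Toy latent-space sizes; real values depend on the VAE and requested resolution.
num_frames, height, width = 9, 60, 104
p, p_t = 2, 1  # patch_size, patch_size_t
post_patch_height, post_patch_width = height // p, width // p
first_frame_num_tokens = 1 * post_patch_height * post_patch_width
print(first_frame_num_tokens)  # 30 * 52 = 1560 tokens modulated with the t=0 embedding
```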
@@ -726,49 +1082,51 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
 
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-
-                return custom_forward
-
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-
             for block in self.transformer_blocks:
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     temb,
                     attention_mask,
                     image_rotary_emb,
-                    **ckpt_kwargs,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
             for block in self.single_transformer_blocks:
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     temb,
                     attention_mask,
                     image_rotary_emb,
-                    **ckpt_kwargs,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
         else:
             for block in self.transformer_blocks:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    attention_mask,
+                    image_rotary_emb,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
             for block in self.single_transformer_blocks:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    attention_mask,
+                    image_rotary_emb,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
         # 5. Output projection
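End to end, this first-frame conditioning is consumed by the new image-to-video pipelines added in this release (`pipeline_hunyuan_video_image2video.py` and `pipeline_hunyuan_skyreels_image2video.py` in the file list). A hedged sketch using `HunyuanVideoImageToVideoPipeline`; the checkpoint id, input path, and generation parameters are assumptions, not taken from this diff:

```python
import torch
from diffusers import HunyuanVideoImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = HunyuanVideoImageToVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-I2V", torch_dtype=torch.bfloat16
)
pipe.vae.enable_tiling()
pipe.to("cuda")

image = load_image("input.png")  # placeholder conditioning frame
video = pipe(image=image, prompt="A man walks on a beach", num_frames=61).frames[0]
export_to_video(video, "output.mp4", fps=15)
```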