diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/transformer_easyanimate.py (new file)
@@ -0,0 +1,527 @@
+ # Copyright 2025 The EasyAnimate team and The HuggingFace Team.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from ...configuration_utils import ConfigMixin, register_to_config
+ from ...utils import logging
+ from ...utils.torch_utils import maybe_allow_in_graph
+ from ..attention import Attention, FeedForward
+ from ..embeddings import TimestepEmbedding, Timesteps, get_3d_rotary_pos_embed
+ from ..modeling_outputs import Transformer2DModelOutput
+ from ..modeling_utils import ModelMixin
+ from ..normalization import AdaLayerNorm, FP32LayerNorm, RMSNorm
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ class EasyAnimateLayerNormZero(nn.Module):
+     def __init__(
+         self,
+         conditioning_dim: int,
+         embedding_dim: int,
+         elementwise_affine: bool = True,
+         eps: float = 1e-5,
+         bias: bool = True,
+         norm_type: str = "fp32_layer_norm",
+     ) -> None:
+         super().__init__()
+
+         self.silu = nn.SiLU()
+         self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
+
+         if norm_type == "layer_norm":
+             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
+         elif norm_type == "fp32_layer_norm":
+             self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
+         else:
+             raise ValueError(
+                 f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+             )
+
+     def forward(
+         self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
+         hidden_states = self.norm(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+         encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale.unsqueeze(1)) + enc_shift.unsqueeze(
+             1
+         )
+         return hidden_states, encoder_hidden_states, gate, enc_gate
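
EasyAnimateLayerNormZero is an AdaLN-Zero-style modulation layer: a single SiLU + Linear projection of the timestep embedding yields shift/scale/gate triples for both the video stream and the text stream. A minimal sketch of how it is called, assuming diffusers >= 0.33.0 is installed (all sizes below are illustrative, not from a real checkpoint):

import torch
from diffusers.models.transformers.transformer_easyanimate import EasyAnimateLayerNormZero

norm = EasyAnimateLayerNormZero(conditioning_dim=32, embedding_dim=64)
video = torch.randn(2, 17, 64)  # [batch, video tokens, dim]
text = torch.randn(2, 5, 64)    # [batch, text tokens, dim]
temb = torch.randn(2, 32)       # [batch, conditioning_dim]
video, text, gate, enc_gate = norm(video, text, temb)
print(gate.shape)               # torch.Size([2, 64]); the gates are applied by the caller
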
+
+
+ class EasyAnimateRotaryPosEmbed(nn.Module):
+     def __init__(self, patch_size: int, rope_dim: List[int]) -> None:
+         super().__init__()
+
+         self.patch_size = patch_size
+         self.rope_dim = rope_dim
+
+     def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
+         tw = tgt_width
+         th = tgt_height
+         h, w = src
+         r = h / w
+         if r > (th / tw):
+             resize_height = th
+             resize_width = int(round(th / h * w))
+         else:
+             resize_width = tw
+             resize_height = int(round(tw / w * h))
+
+         crop_top = int(round((th - resize_height) / 2.0))
+         crop_left = int(round((tw - resize_width) / 2.0))
+
+         return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         bs, c, num_frames, grid_height, grid_width = hidden_states.size()
+         grid_height = grid_height // self.patch_size
+         grid_width = grid_width // self.patch_size
+         base_size_width = 90 // self.patch_size
+         base_size_height = 60 // self.patch_size
+
+         grid_crops_coords = self.get_resize_crop_region_for_grid(
+             (grid_height, grid_width), base_size_width, base_size_height
+         )
+         image_rotary_emb = get_3d_rotary_pos_embed(
+             self.rope_dim,
+             grid_crops_coords,
+             grid_size=(grid_height, grid_width),
+             temporal_size=hidden_states.size(2),
+             use_real=True,
+         )
+         return image_rotary_emb
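
The crop-region helper fits the actual patch grid, aspect ratio preserved, into the 45x30 base grid implied by the hardcoded 90x60 base resolution and a patch size of 2; the 3D rotary table is then built for that centered window. An illustrative call with round numbers, again assuming diffusers >= 0.33.0:

import torch
from diffusers.models.transformers.transformer_easyanimate import EasyAnimateRotaryPosEmbed

rope = EasyAnimateRotaryPosEmbed(patch_size=2, rope_dim=64)  # the model passes attention_head_dim as rope_dim
top_left, bottom_right = rope.get_resize_crop_region_for_grid((30, 35), 45, 30)
print(top_left, bottom_right)  # (0, 5) (30, 40): a centered 30x35 window inside the 45x30 grid
freqs_cos, freqs_sin = rope(torch.randn(1, 4, 3, 16, 16))    # [B, C, F, H, W] latents
print(freqs_cos.shape)         # torch.Size([192, 64]): one row per (frame, patch) token, 3 * 8 * 8 = 192
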
+
+
+ class EasyAnimateAttnProcessor2_0:
+     r"""
+     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+     used in the EasyAnimateTransformer3DModel model.
+     """
+
+     def __init__(self):
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError(
+                 "EasyAnimateAttnProcessor2_0 requires PyTorch 2.0 or above. To use it, please install PyTorch 2.0."
+             )
+
+     def __call__(
+         self,
+         attn: Attention,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         image_rotary_emb: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         if attn.add_q_proj is None and encoder_hidden_states is not None:
+             hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+         # 1. QKV projections
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(hidden_states)
+         value = attn.to_v(hidden_states)
+
+         query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+         key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+         value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+         # 2. QK normalization
+         if attn.norm_q is not None:
+             query = attn.norm_q(query)
+         if attn.norm_k is not None:
+             key = attn.norm_k(key)
+
+         # 3. Encoder condition QKV projection and normalization
+         if attn.add_q_proj is not None and encoder_hidden_states is not None:
+             encoder_query = attn.add_q_proj(encoder_hidden_states)
+             encoder_key = attn.add_k_proj(encoder_hidden_states)
+             encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+             encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+             encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+             encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+             if attn.norm_added_q is not None:
+                 encoder_query = attn.norm_added_q(encoder_query)
+             if attn.norm_added_k is not None:
+                 encoder_key = attn.norm_added_k(encoder_key)
+
+             query = torch.cat([encoder_query, query], dim=2)
+             key = torch.cat([encoder_key, key], dim=2)
+             value = torch.cat([encoder_value, value], dim=2)
+
+         # 4. Rotary position embedding, applied to the video tokens only
+         if image_rotary_emb is not None:
+             from ..embeddings import apply_rotary_emb
+
+             query[:, :, encoder_hidden_states.shape[1] :] = apply_rotary_emb(
+                 query[:, :, encoder_hidden_states.shape[1] :], image_rotary_emb
+             )
+             if not attn.is_cross_attention:
+                 key[:, :, encoder_hidden_states.shape[1] :] = apply_rotary_emb(
+                     key[:, :, encoder_hidden_states.shape[1] :], image_rotary_emb
+                 )
+
+         # 5. Attention
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+         hidden_states = hidden_states.to(query.dtype)
+
+         # 6. Output projection
+         if encoder_hidden_states is not None:
+             encoder_hidden_states, hidden_states = (
+                 hidden_states[:, : encoder_hidden_states.shape[1]],
+                 hidden_states[:, encoder_hidden_states.shape[1] :],
+             )
+
+             if getattr(attn, "to_out", None) is not None:
+                 hidden_states = attn.to_out[0](hidden_states)
+                 hidden_states = attn.to_out[1](hidden_states)
+
+             if getattr(attn, "to_add_out", None) is not None:
+                 encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+         else:
+             if getattr(attn, "to_out", None) is not None:
+                 hidden_states = attn.to_out[0](hidden_states)
+                 hidden_states = attn.to_out[1](hidden_states)
+
+         return hidden_states, encoder_hidden_states
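
The processor implements MMDiT-style joint attention: text tokens are concatenated in front of the video tokens, so rotary embeddings are applied only to the trailing (video) positions via the `encoder_hidden_states.shape[1] :` slice. A plain-tensor sketch of that layout (no library calls involved):

import torch

text_len, video_len, heads, head_dim = 5, 192, 2, 64
query = torch.randn(1, heads, text_len + video_len, head_dim)  # text tokens first, then video tokens
video_slice = query[:, :, text_len:]                           # the part that receives RoPE
assert video_slice.shape[2] == video_len                       # text positions stay rotation-free
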
+
+
+ @maybe_allow_in_graph
+ class EasyAnimateTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         attention_head_dim: int,
+         time_embed_dim: int,
+         dropout: float = 0.0,
+         activation_fn: str = "gelu-approximate",
+         norm_elementwise_affine: bool = True,
+         norm_eps: float = 1e-6,
+         final_dropout: bool = True,
+         ff_inner_dim: Optional[int] = None,
+         ff_bias: bool = True,
+         qk_norm: bool = True,
+         after_norm: bool = False,
+         norm_type: str = "fp32_layer_norm",
+         is_mmdit_block: bool = True,
+     ):
+         super().__init__()
+
+         # Attention Part
+         self.norm1 = EasyAnimateLayerNormZero(
+             time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
+         )
+
+         self.attn1 = Attention(
+             query_dim=dim,
+             dim_head=attention_head_dim,
+             heads=num_attention_heads,
+             qk_norm="layer_norm" if qk_norm else None,
+             eps=1e-6,
+             bias=True,
+             added_proj_bias=True,
+             added_kv_proj_dim=dim if is_mmdit_block else None,
+             context_pre_only=False if is_mmdit_block else None,
+             processor=EasyAnimateAttnProcessor2_0(),
+         )
+
+         # FFN Part
+         self.norm2 = EasyAnimateLayerNormZero(
+             time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
+         )
+         self.ff = FeedForward(
+             dim,
+             dropout=dropout,
+             activation_fn=activation_fn,
+             final_dropout=final_dropout,
+             inner_dim=ff_inner_dim,
+             bias=ff_bias,
+         )
+
+         self.txt_ff = None
+         if is_mmdit_block:
+             self.txt_ff = FeedForward(
+                 dim,
+                 dropout=dropout,
+                 activation_fn=activation_fn,
+                 final_dropout=final_dropout,
+                 inner_dim=ff_inner_dim,
+                 bias=ff_bias,
+             )
+
+         self.norm3 = None
+         if after_norm:
+             self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         temb: torch.Tensor,
+         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # 1. Attention
+         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+             hidden_states, encoder_hidden_states, temb
+         )
+         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+             hidden_states=norm_hidden_states,
+             encoder_hidden_states=norm_encoder_hidden_states,
+             image_rotary_emb=image_rotary_emb,
+         )
+         hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
+         encoder_hidden_states = encoder_hidden_states + enc_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
+
+         # 2. Feed-forward
+         norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+             hidden_states, encoder_hidden_states, temb
+         )
+         if self.norm3 is not None:
+             norm_hidden_states = self.norm3(self.ff(norm_hidden_states))
+             if self.txt_ff is not None:
+                 norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states))
+             else:
+                 norm_encoder_hidden_states = self.norm3(self.ff(norm_encoder_hidden_states))
+         else:
+             norm_hidden_states = self.ff(norm_hidden_states)
+             if self.txt_ff is not None:
+                 norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states)
+             else:
+                 norm_encoder_hidden_states = self.ff(norm_encoder_hidden_states)
+         hidden_states = hidden_states + gate_ff.unsqueeze(1) * norm_hidden_states
+         encoder_hidden_states = encoder_hidden_states + enc_gate_ff.unsqueeze(1) * norm_encoder_hidden_states
+         return hidden_states, encoder_hidden_states
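
A single block can be smoke-tested in isolation; a minimal sketch assuming diffusers >= 0.33.0, with illustrative dimensions and `image_rotary_emb` omitted:

import torch
from diffusers.models.transformers.transformer_easyanimate import EasyAnimateTransformerBlock

block = EasyAnimateTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16, time_embed_dim=32)
video = torch.randn(1, 17, 64)
text = torch.randn(1, 5, 64)
temb = torch.randn(1, 32)
video, text = block(video, text, temb)  # both streams keep their shapes
print(video.shape, text.shape)          # torch.Size([1, 17, 64]) torch.Size([1, 5, 64])
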
+
+
+ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
+     """
+     A Transformer model for video-like data in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate).
+
+     Parameters:
+         num_attention_heads (`int`, defaults to `48`):
+             The number of heads to use for multi-head attention.
+         attention_head_dim (`int`, defaults to `64`):
+             The number of channels in each head.
+         in_channels (`int`, *optional*):
+             The number of channels in the input.
+         out_channels (`int`, *optional*):
+             The number of channels in the output.
+         patch_size (`int`, *optional*):
+             The size of the patches to use in the patch embedding layer.
+         sample_width (`int`, defaults to `90`):
+             The width of the input latents.
+         sample_height (`int`, defaults to `60`):
+             The height of the input latents.
+         activation_fn (`str`, defaults to `"gelu-approximate"`):
+             Activation function to use in feed-forward.
+         timestep_activation_fn (`str`, defaults to `"silu"`):
+             Activation function to use when generating the timestep embeddings.
+         num_layers (`int`, defaults to `48`):
+             The number of layers of Transformer blocks to use.
+         mmdit_layers (`int`, defaults to `48`):
+             The number of layers of Multi Modal Transformer blocks to use.
+         dropout (`float`, defaults to `0.0`):
+             The dropout probability to use.
+         time_embed_dim (`int`, defaults to `512`):
+             Output dimension of timestep embeddings.
+         text_embed_dim (`int`, defaults to `3584`):
+             Input dimension of text embeddings from the text encoder.
+         norm_eps (`float`, defaults to `1e-5`):
+             The epsilon value to use in normalization layers.
+         norm_elementwise_affine (`bool`, defaults to `True`):
+             Whether to use elementwise affine in normalization layers.
+         flip_sin_to_cos (`bool`, defaults to `True`):
+             Whether to flip the sin to cos in the time embedding.
+         time_position_encoding_type (`str`, defaults to `3d_rope`):
+             Type of time position encoding.
+         after_norm (`bool`, defaults to `False`):
+             Flag to apply an extra FP32 layer norm after the feed-forward in each transformer block.
+         resize_inpaint_mask_directly (`bool`, defaults to `True`):
+             Flag to resize the inpaint mask directly.
+         enable_text_attention_mask (`bool`, defaults to `True`):
+             Flag to enable the text attention mask.
+         add_noise_in_inpaint_model (`bool`, defaults to `True`):
+             Flag to add noise in the inpaint model.
+     """
+
+     _supports_gradient_checkpointing = True
+     _no_split_modules = ["EasyAnimateTransformerBlock"]
+     _skip_layerwise_casting_patterns = ["^proj$", "norm", "^proj_out$"]
+
+     @register_to_config
+     def __init__(
+         self,
+         num_attention_heads: int = 48,
+         attention_head_dim: int = 64,
+         in_channels: Optional[int] = None,
+         out_channels: Optional[int] = None,
+         patch_size: Optional[int] = None,
+         sample_width: int = 90,
+         sample_height: int = 60,
+         activation_fn: str = "gelu-approximate",
+         timestep_activation_fn: str = "silu",
+         freq_shift: int = 0,
+         num_layers: int = 48,
+         mmdit_layers: int = 48,
+         dropout: float = 0.0,
+         time_embed_dim: int = 512,
+         add_norm_text_encoder: bool = False,
+         text_embed_dim: int = 3584,
+         text_embed_dim_t5: int = None,
+         norm_eps: float = 1e-5,
+         norm_elementwise_affine: bool = True,
+         flip_sin_to_cos: bool = True,
+         time_position_encoding_type: str = "3d_rope",
+         after_norm=False,
+         resize_inpaint_mask_directly: bool = True,
+         enable_text_attention_mask: bool = True,
+         add_noise_in_inpaint_model: bool = True,
+     ):
+         super().__init__()
+         inner_dim = num_attention_heads * attention_head_dim
+
+         # 1. Timestep embedding
+         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+         self.rope_embedding = EasyAnimateRotaryPosEmbed(patch_size, attention_head_dim)
+
+         # 2. Patch embedding
+         self.proj = nn.Conv2d(
+             in_channels, inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
+         )
+
+         # 3. Text refined embedding
+         self.text_proj = None
+         self.text_proj_t5 = None
+         if not add_norm_text_encoder:
+             self.text_proj = nn.Linear(text_embed_dim, inner_dim)
+             if text_embed_dim_t5 is not None:
+                 self.text_proj_t5 = nn.Linear(text_embed_dim_t5, inner_dim)
+         else:
+             self.text_proj = nn.Sequential(
+                 RMSNorm(text_embed_dim, 1e-6, elementwise_affine=True), nn.Linear(text_embed_dim, inner_dim)
+             )
+             if text_embed_dim_t5 is not None:
+                 self.text_proj_t5 = nn.Sequential(
+                     RMSNorm(text_embed_dim, 1e-6, elementwise_affine=True), nn.Linear(text_embed_dim_t5, inner_dim)
+                 )
+
+         # 4. Transformer blocks
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 EasyAnimateTransformerBlock(
+                     dim=inner_dim,
+                     num_attention_heads=num_attention_heads,
+                     attention_head_dim=attention_head_dim,
+                     time_embed_dim=time_embed_dim,
+                     dropout=dropout,
+                     activation_fn=activation_fn,
+                     norm_elementwise_affine=norm_elementwise_affine,
+                     norm_eps=norm_eps,
+                     after_norm=after_norm,
+                     is_mmdit_block=True if _ < mmdit_layers else False,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+         self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
+
+         # 5. Output norm & projection
+         self.norm_out = AdaLayerNorm(
+             embedding_dim=time_embed_dim,
+             output_dim=2 * inner_dim,
+             norm_elementwise_affine=norm_elementwise_affine,
+             norm_eps=norm_eps,
+             chunk_dim=1,
+         )
+         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.Tensor,
+         timestep_cond: Optional[torch.Tensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_hidden_states_t5: Optional[torch.Tensor] = None,
+         inpaint_latents: Optional[torch.Tensor] = None,
+         control_latents: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+     ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
+         batch_size, channels, video_length, height, width = hidden_states.size()
+         p = self.config.patch_size
+         post_patch_height = height // p
+         post_patch_width = width // p
+
+         # 1. Time embedding
+         temb = self.time_proj(timestep).to(dtype=hidden_states.dtype)
+         temb = self.time_embedding(temb, timestep_cond)
+         image_rotary_emb = self.rope_embedding(hidden_states)
+
+         # 2. Patch embedding
+         if inpaint_latents is not None:
+             hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
+         if control_latents is not None:
+             hidden_states = torch.concat([hidden_states, control_latents], 1)
+
+         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)  # [B, C, F, H, W] -> [BF, C, H, W]
+         hidden_states = self.proj(hidden_states)
+         hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
+             0, 2, 1, 3, 4
+         )  # [BF, C, H, W] -> [B, C, F, H, W]
+         hidden_states = hidden_states.flatten(2, 4).transpose(1, 2)  # [B, C, F, H, W] -> [B, FHW, C]
+
+         # 3. Text embedding
+         encoder_hidden_states = self.text_proj(encoder_hidden_states)
+         if encoder_hidden_states_t5 is not None:
+             encoder_hidden_states_t5 = self.text_proj_t5(encoder_hidden_states_t5)
+             encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1).contiguous()
+
+         # 4. Transformer blocks
+         for block in self.transformer_blocks:
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
+                 hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                     block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                 )
+             else:
+                 hidden_states, encoder_hidden_states = block(
+                     hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                 )
+
+         hidden_states = self.norm_final(hidden_states)
+
+         # 5. Output norm & projection
+         hidden_states = self.norm_out(hidden_states, temb=temb)
+         hidden_states = self.proj_out(hidden_states)
+
+         # 6. Unpatchify
+         p = self.config.patch_size
+         output = hidden_states.reshape(batch_size, video_length, post_patch_height, post_patch_width, channels, p, p)
+         output = output.permute(0, 4, 1, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+         if not return_dict:
+             return (output,)
+         return Transformer2DModelOutput(sample=output)
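
To close, an end-to-end smoke test of the new model with a deliberately tiny configuration, assuming diffusers >= 0.33.0; released EasyAnimate checkpoints use far larger values (e.g. 48 layers, `text_embed_dim=3584`):

import torch
from diffusers.models.transformers.transformer_easyanimate import EasyAnimateTransformer3DModel

model = EasyAnimateTransformer3DModel(
    num_attention_heads=2,
    attention_head_dim=16,  # keeps the 3D RoPE channel split even
    in_channels=4,
    out_channels=4,
    patch_size=2,
    num_layers=2,
    mmdit_layers=1,         # only the first block carries the separate text stream
    time_embed_dim=32,
    text_embed_dim=16,
)
latents = torch.randn(1, 4, 3, 16, 16)  # [batch, channels, frames, height, width]
text = torch.randn(1, 5, 16)            # [batch, text tokens, text_embed_dim]
out = model(latents, torch.tensor([10]), encoder_hidden_states=text).sample
print(out.shape)                        # torch.Size([1, 4, 3, 16, 16])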