diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
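The most visible additions in 0.33.0 are the new `diffusers/hooks/` subpackage (group offloading, layerwise casting, pyramid attention broadcast, FasterCache), a `quanto` quantization backend, and a wave of new pipelines (CogView4, ConsisID, EasyAnimate, Lumina2, OmniGen, SANA Sprint, Wan). As a rough orientation for readers of this diff, the minimal sketch below shows how the layerwise-casting hook added in `diffusers/hooks/layerwise_casting.py` is intended to be driven from a pipeline. The `enable_layerwise_casting` method name and its arguments are inferred from the new module and may not match the released API exactly, and the checkpoint id is only illustrative.

import torch
from diffusers import DiffusionPipeline

# Load any pipeline whose denoiser inherits from ModelMixin; the checkpoint id below is illustrative.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Assumed hook API from diffusers/hooks/layerwise_casting.py: keep weights stored in float8
# and upcast them layer by layer to bfloat16 only while that layer computes.
pipe.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = pipe("a minimalist poster of a lighthouse at dusk", num_inference_steps=28).images[0]
image.save("lighthouse.png")

The other hooks listed above (for example `hooks/group_offloading.py` and `hooks/pyramid_attention_broadcast.py`) follow the same opt-in pattern of attaching a behavior to an already loaded model, so the sketch generalizes to them with different enable calls.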
diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py
@@ -0,0 +1,994 @@
+ # Copyright 2025 The EasyAnimate team and The HuggingFace Team.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from typing import Callable, Dict, List, Optional, Union
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ from transformers import (
+     BertModel,
+     BertTokenizer,
+     Qwen2Tokenizer,
+     Qwen2VLForConditionalGeneration,
+ )
+
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+ from ...image_processor import VaeImageProcessor
+ from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
+ from ...pipelines.pipeline_utils import DiffusionPipeline
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
+ from ...utils import is_torch_xla_available, logging, replace_example_docstring
+ from ...utils.torch_utils import randn_tensor
+ from ...video_processor import VideoProcessor
+ from .pipeline_output import EasyAnimatePipelineOutput
+
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ EXAMPLE_DOC_STRING = """
+     Examples:
+         ```python
+         >>> import torch
+         >>> from diffusers import EasyAnimateControlPipeline
+         >>> from diffusers.pipelines.easyanimate.pipeline_easyanimate_control import get_video_to_video_latent
+         >>> from diffusers.utils import export_to_video, load_video
+
+         >>> pipe = EasyAnimateControlPipeline.from_pretrained(
+         ...     "alibaba-pai/EasyAnimateV5.1-12b-zh-Control-diffusers", torch_dtype=torch.bfloat16
+         ... )
+         >>> pipe.to("cuda")
+
+         >>> control_video = load_video(
+         ...     "https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control/blob/main/asset/pose.mp4"
+         ... )
+         >>> prompt = (
+         ...     "In this sunlit outdoor garden, a beautiful woman is dressed in a knee-length, sleeveless white dress. "
+         ...     "The hem of her dress gently sways with her graceful dance, much like a butterfly fluttering in the breeze. "
+         ...     "Sunlight filters through the leaves, casting dappled shadows that highlight her soft features and clear eyes, "
+         ...     "making her appear exceptionally elegant. It seems as if every movement she makes speaks of youth and vitality. "
+         ...     "As she twirls on the grass, her dress flutters, as if the entire garden is rejoicing in her dance. "
+         ...     "The colorful flowers around her sway in the gentle breeze, with roses, chrysanthemums, and lilies each "
+         ...     "releasing their fragrances, creating a relaxed and joyful atmosphere."
+         ... )
+         >>> sample_size = (672, 384)
+         >>> num_frames = 49
+
+         >>> input_video, _, _ = get_video_to_video_latent(control_video, num_frames, sample_size)
+         >>> video = pipe(
+         ...     prompt,
+         ...     num_frames=num_frames,
+         ...     negative_prompt="Twisted body, limb deformities, text subtitles, comics, stillness, ugliness, errors, garbled text.",
+         ...     height=sample_size[0],
+         ...     width=sample_size[1],
+         ...     control_video=input_video,
+         ... ).frames[0]
+         >>> export_to_video(video, "output.mp4", fps=8)
+         ```
+ """
+
+
+ def preprocess_image(image, sample_size):
+     """
+     Preprocess a single image (PIL.Image, numpy.ndarray, or torch.Tensor) to a resized tensor.
+     """
+     if isinstance(image, torch.Tensor):
+         # If input is a tensor, assume it's in CHW format and resize using interpolation
+         image = torch.nn.functional.interpolate(
+             image.unsqueeze(0), size=sample_size, mode="bilinear", align_corners=False
+         ).squeeze(0)
+     elif isinstance(image, Image.Image):
+         # If input is a PIL image, resize and convert to numpy array
+         image = image.resize((sample_size[1], sample_size[0]))
+         image = np.array(image)
+     elif isinstance(image, np.ndarray):
+         # If input is a numpy array, resize using PIL
+         image = Image.fromarray(image).resize((sample_size[1], sample_size[0]))
+         image = np.array(image)
+     else:
+         raise ValueError("Unsupported input type. Expected PIL.Image, numpy.ndarray, or torch.Tensor.")
+
+     # Convert to tensor if not already
+     if not isinstance(image, torch.Tensor):
+         image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0  # HWC -> CHW, normalize to [0, 1]
+
+     return image
+
+
+ def get_video_to_video_latent(input_video, num_frames, sample_size, validation_video_mask=None, ref_image=None):
+     if input_video is not None:
+         # Convert each frame in the list to tensor
+         input_video = [preprocess_image(frame, sample_size=sample_size) for frame in input_video]
+
+         # Stack all frames into a single tensor (F, C, H, W)
+         input_video = torch.stack(input_video)[:num_frames]
+
+         # Move channels first and add batch dimension (B, C, F, H, W)
+         input_video = input_video.permute(1, 0, 2, 3).unsqueeze(0)
+
+         if validation_video_mask is not None:
+             # Handle mask input
+             validation_video_mask = preprocess_image(validation_video_mask, sample_size=sample_size)
+             input_video_mask = torch.where(validation_video_mask < 240 / 255.0, 0.0, 255)
+
+             # Adjust mask dimensions to match video
+             input_video_mask = input_video_mask.unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
+             input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
+             input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
+         else:
+             input_video_mask = torch.zeros_like(input_video[:, :1])
+             input_video_mask[:, :, :] = 255
+     else:
+         input_video, input_video_mask = None, None
+
+     if ref_image is not None:
+         # Convert reference image to tensor
+         ref_image = preprocess_image(ref_image, sample_size=sample_size)
+         ref_image = ref_image.permute(1, 0, 2, 3).unsqueeze(0)  # Add batch dimension (B, C, H, W)
+     else:
+         ref_image = None
+
+     return input_video, input_video_mask, ref_image
+
+
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+     tw = tgt_width
+     th = tgt_height
+     h, w = src
+     r = h / w
+     if r > (th / tw):
+         resize_height = th
+         resize_width = int(round(th / h * w))
+     else:
+         resize_width = tw
+         resize_height = int(round(tw / w * h))
+
+     crop_top = int(round((th - resize_height) / 2.0))
+     crop_left = int(round((tw - resize_width) / 2.0))
+
+     return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+     r"""
+     Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+     Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+     Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+     Args:
+         noise_cfg (`torch.Tensor`):
+             The predicted noise tensor for the guided diffusion process.
+         noise_pred_text (`torch.Tensor`):
+             The predicted noise tensor for the text-guided diffusion process.
+         guidance_rescale (`float`, *optional*, defaults to 0.0):
+             A rescale factor applied to the noise predictions.
+
+     Returns:
+         noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+     """
+     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+     # rescale the results from guidance (fixes overexposure)
+     noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+     # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+     noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+     return noise_cfg
+
+
+ # Resize mask information in magvit
+ def resize_mask(mask, latent, process_first_frame_only=True):
+     latent_size = latent.size()
+
+     if process_first_frame_only:
+         target_size = list(latent_size[2:])
+         target_size[0] = 1
+         first_frame_resized = F.interpolate(
+             mask[:, :, 0:1, :, :], size=target_size, mode="trilinear", align_corners=False
+         )
+
+         target_size = list(latent_size[2:])
+         target_size[0] = target_size[0] - 1
+         if target_size[0] != 0:
+             remaining_frames_resized = F.interpolate(
+                 mask[:, :, 1:, :, :], size=target_size, mode="trilinear", align_corners=False
+             )
+             resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
+         else:
+             resized_mask = first_frame_resized
+     else:
+         target_size = list(latent_size[2:])
+         resized_mask = F.interpolate(mask, size=target_size, mode="trilinear", align_corners=False)
+     return resized_mask
+
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+ def retrieve_timesteps(
+     scheduler,
+     num_inference_steps: Optional[int] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     timesteps: Optional[List[int]] = None,
+     sigmas: Optional[List[float]] = None,
+     **kwargs,
+ ):
+     r"""
+     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+     Args:
+         scheduler (`SchedulerMixin`):
+             The scheduler to get timesteps from.
+         num_inference_steps (`int`):
+             The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+             must be `None`.
+         device (`str` or `torch.device`, *optional*):
+             The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+         timesteps (`List[int]`, *optional*):
+             Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+             `num_inference_steps` and `sigmas` must be `None`.
+         sigmas (`List[float]`, *optional*):
+             Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+             `num_inference_steps` and `timesteps` must be `None`.
+
+     Returns:
+         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+         second element is the number of inference steps.
+     """
+     if timesteps is not None and sigmas is not None:
+         raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+     if timesteps is not None:
+         accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+         if not accepts_timesteps:
+             raise ValueError(
+                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                 f" timestep schedules. Please check whether you are using the correct scheduler."
+             )
+         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     elif sigmas is not None:
+         accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+         if not accept_sigmas:
+             raise ValueError(
+                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                 f" sigmas schedules. Please check whether you are using the correct scheduler."
+             )
+         scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     else:
+         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+     return timesteps, num_inference_steps
+
+
+ class EasyAnimateControlPipeline(DiffusionPipeline):
+     r"""
+     Pipeline for text-to-video generation using EasyAnimate.
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+     EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
+
+     Args:
+         vae ([`AutoencoderKLMagvit`]):
+             Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
+         text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
+             EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
+         tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
+             A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
+         transformer ([`EasyAnimateTransformer3DModel`]):
+             The EasyAnimate model designed by EasyAnimate Team.
+         scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+             A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
+     """
+
+     model_cpu_offload_seq = "text_encoder->transformer->vae"
+     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+     def __init__(
+         self,
+         vae: AutoencoderKLMagvit,
+         text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
+         tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
+         transformer: EasyAnimateTransformer3DModel,
+         scheduler: FlowMatchEulerDiscreteScheduler,
+     ):
+         super().__init__()
+
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             transformer=transformer,
+             scheduler=scheduler,
+         )
+
+         self.enable_text_attention_mask = (
+             self.transformer.config.enable_text_attention_mask
+             if getattr(self, "transformer", None) is not None
+             else True
+         )
+         self.vae_spatial_compression_ratio = (
+             self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8
+         )
+         self.vae_temporal_compression_ratio = (
+             self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4
+         )
+         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+         self.mask_processor = VaeImageProcessor(
+             vae_scale_factor=self.vae_spatial_compression_ratio,
+             do_normalize=False,
+             do_binarize=True,
+             do_convert_grayscale=True,
+         )
+         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+
+     # Copied from diffusers.pipelines.easyanimate.pipeline_easyanimate.EasyAnimatePipeline.encode_prompt
+     def encode_prompt(
+         self,
+         prompt: Union[str, List[str]],
+         num_images_per_prompt: int = 1,
+         do_classifier_free_guidance: bool = True,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         prompt_attention_mask: Optional[torch.Tensor] = None,
+         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+         max_sequence_length: int = 256,
+     ):
+         r"""
+         Encodes the prompt into text encoder hidden states.
+
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 prompt to be encoded
+             device: (`torch.device`):
+                 torch device
+             dtype (`torch.dtype`):
+                 torch dtype
+             num_images_per_prompt (`int`):
+                 number of images that should be generated per prompt
+             do_classifier_free_guidance (`bool`):
+                 whether to use classifier free guidance or not
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                 less than `1`).
+             prompt_embeds (`torch.Tensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                 provided, text embeddings will be generated from `prompt` input argument.
+             negative_prompt_embeds (`torch.Tensor`, *optional*):
+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                 argument.
+             prompt_attention_mask (`torch.Tensor`, *optional*):
+                 Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
+             negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+                 Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
+             max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
+         """
+         dtype = dtype or self.text_encoder.dtype
+         device = device or self.text_encoder.device
+
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         if prompt_embeds is None:
+             if isinstance(prompt, str):
+                 messages = [
+                     {
+                         "role": "user",
+                         "content": [{"type": "text", "text": prompt}],
+                     }
+                 ]
+             else:
+                 messages = [
+                     {
+                         "role": "user",
+                         "content": [{"type": "text", "text": _prompt}],
+                     }
+                     for _prompt in prompt
+                 ]
+             text = [
+                 self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
+             ]
+
+             text_inputs = self.tokenizer(
+                 text=text,
+                 padding="max_length",
+                 max_length=max_sequence_length,
+                 truncation=True,
+                 return_attention_mask=True,
+                 padding_side="right",
+                 return_tensors="pt",
+             )
+             text_inputs = text_inputs.to(self.text_encoder.device)
+
+             text_input_ids = text_inputs.input_ids
+             prompt_attention_mask = text_inputs.attention_mask
+             if self.enable_text_attention_mask:
+                 # Inference: Generation of the output
+                 prompt_embeds = self.text_encoder(
+                     input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
+                 ).hidden_states[-2]
+             else:
+                 raise ValueError("LLM needs attention_mask")
+             prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+         bs_embed, seq_len, _ = prompt_embeds.shape
+         # duplicate text embeddings for each generation per prompt, using mps friendly method
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+         prompt_attention_mask = prompt_attention_mask.to(device=device)
+
+         # get unconditional embeddings for classifier free guidance
+         if do_classifier_free_guidance and negative_prompt_embeds is None:
+             if negative_prompt is not None and isinstance(negative_prompt, str):
+                 messages = [
+                     {
+                         "role": "user",
+                         "content": [{"type": "text", "text": negative_prompt}],
+                     }
+                 ]
+             else:
+                 messages = [
+                     {
+                         "role": "user",
+                         "content": [{"type": "text", "text": _negative_prompt}],
+                     }
+                     for _negative_prompt in negative_prompt
+                 ]
+             text = [
+                 self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
+             ]
+
+             text_inputs = self.tokenizer(
+                 text=text,
+                 padding="max_length",
+                 max_length=max_sequence_length,
+                 truncation=True,
+                 return_attention_mask=True,
+                 padding_side="right",
+                 return_tensors="pt",
+             )
+             text_inputs = text_inputs.to(self.text_encoder.device)
+
+             text_input_ids = text_inputs.input_ids
+             negative_prompt_attention_mask = text_inputs.attention_mask
+             if self.enable_text_attention_mask:
+                 # Inference: Generation of the output
+                 negative_prompt_embeds = self.text_encoder(
+                     input_ids=text_input_ids,
+                     attention_mask=negative_prompt_attention_mask,
+                     output_hidden_states=True,
+                 ).hidden_states[-2]
+             else:
+                 raise ValueError("LLM needs attention_mask")
+             negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+         if do_classifier_free_guidance:
+             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+             seq_len = negative_prompt_embeds.shape[1]
+
+             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+             negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
+
+         return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+     def prepare_extra_step_kwargs(self, generator, eta):
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be between [0, 1]
+
+         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         # check if the scheduler accepts generator
+         accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         if accepts_generator:
+             extra_step_kwargs["generator"] = generator
+         return extra_step_kwargs
+
+     def check_inputs(
+         self,
+         prompt,
+         height,
+         width,
+         negative_prompt=None,
+         prompt_embeds=None,
+         negative_prompt_embeds=None,
+         prompt_attention_mask=None,
+         negative_prompt_attention_mask=None,
+         callback_on_step_end_tensor_inputs=None,
+     ):
+         if height % 16 != 0 or width % 16 != 0:
+             raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+         if callback_on_step_end_tensor_inputs is not None and not all(
+             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+         ):
+             raise ValueError(
+                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+             )
+
+         if prompt is not None and prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                 " only forward one of the two."
+             )
+         elif prompt is None and prompt_embeds is None:
+             raise ValueError(
+                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+             )
+         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+         if prompt_embeds is not None and prompt_attention_mask is None:
+             raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+         if negative_prompt is not None and negative_prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+             )
+
+         if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+             raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+         if prompt_embeds is not None and negative_prompt_embeds is not None:
+             if prompt_embeds.shape != negative_prompt_embeds.shape:
+                 raise ValueError(
+                     "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                     f" {negative_prompt_embeds.shape}."
+                 )
+
+     def prepare_latents(
+         self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+     ):
+         if latents is not None:
+             return latents.to(device=device, dtype=dtype)
+
+         shape = (
+             batch_size,
+             num_channels_latents,
+             (num_frames - 1) // self.vae_temporal_compression_ratio + 1,
+             height // self.vae_spatial_compression_ratio,
+             width // self.vae_spatial_compression_ratio,
+         )
+
+         if isinstance(generator, list) and len(generator) != batch_size:
+             raise ValueError(
+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+             )
+
+         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+         # scale the initial noise by the standard deviation required by the scheduler
+         if hasattr(self.scheduler, "init_noise_sigma"):
+             latents = latents * self.scheduler.init_noise_sigma
+         return latents
+
+     def prepare_control_latents(
+         self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+     ):
+         # resize the control to latents shape as we concatenate the control to the latents
+         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+         # and half precision
+
+         if control is not None:
+             control = control.to(device=device, dtype=dtype)
+             bs = 1
+             new_control = []
+             for i in range(0, control.shape[0], bs):
+                 control_bs = control[i : i + bs]
+                 control_bs = self.vae.encode(control_bs)[0]
+                 control_bs = control_bs.mode()
+                 new_control.append(control_bs)
+             control = torch.cat(new_control, dim=0)
+             control = control * self.vae.config.scaling_factor
+
+         if control_image is not None:
+             control_image = control_image.to(device=device, dtype=dtype)
+             bs = 1
+             new_control_pixel_values = []
+             for i in range(0, control_image.shape[0], bs):
+                 control_pixel_values_bs = control_image[i : i + bs]
+                 control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
+                 control_pixel_values_bs = control_pixel_values_bs.mode()
+                 new_control_pixel_values.append(control_pixel_values_bs)
+             control_image_latents = torch.cat(new_control_pixel_values, dim=0)
+             control_image_latents = control_image_latents * self.vae.config.scaling_factor
+         else:
+             control_image_latents = None
+
+         return control, control_image_latents
+
+     @property
+     def guidance_scale(self):
+         return self._guidance_scale
+
+     @property
+     def guidance_rescale(self):
+         return self._guidance_rescale
+
+     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+     # corresponds to doing no classifier free guidance.
+     @property
+     def do_classifier_free_guidance(self):
+         return self._guidance_scale > 1
+
+     @property
+     def num_timesteps(self):
+         return self._num_timesteps
+
+     @property
+     def interrupt(self):
+         return self._interrupt
+
+     @torch.no_grad()
+     @replace_example_docstring(EXAMPLE_DOC_STRING)
+     def __call__(
+         self,
+         prompt: Union[str, List[str]] = None,
+         num_frames: Optional[int] = 49,
+         height: Optional[int] = 512,
+         width: Optional[int] = 512,
+         control_video: Union[torch.FloatTensor] = None,
+         control_camera_video: Union[torch.FloatTensor] = None,
+         ref_image: Union[torch.FloatTensor] = None,
+         num_inference_steps: Optional[int] = 50,
+         guidance_scale: Optional[float] = 5.0,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         eta: Optional[float] = 0.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.Tensor] = None,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         prompt_attention_mask: Optional[torch.Tensor] = None,
+         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback_on_step_end: Optional[
+             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+         ] = None,
+         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+         guidance_rescale: float = 0.0,
+         timesteps: Optional[List[int]] = None,
+     ):
701
+ r"""
702
+ Generates images or video using the EasyAnimate pipeline based on the provided prompts.
703
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
+ num_frames (`int`, *optional*, defaults to 49):
+ Length of the generated video (in frames).
+ height (`int`, *optional*, defaults to 512):
+ Height of the generated image in pixels.
+ width (`int`, *optional*, defaults to 512):
+ Width of the generated image in pixels.
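+ control_video (`torch.FloatTensor`, *optional*):
+ A control video of shape `(batch, channels, frames, height, width)`; it is encoded with the VAE and
+ passed to the transformer as control latents.
+ control_camera_video (`torch.FloatTensor`, *optional*):
+ A camera-control video; it is resized directly onto the latent grid (without VAE encoding) and used
+ as the control latents.
+ ref_image (`torch.FloatTensor`, *optional*):
+ A reference image whose VAE latents are placed in the first temporal latent slot to condition the
+ generation.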
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ Number of denoising steps during generation. More steps generally yield higher quality images but slow
+ down inference.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Classifier-free guidance scale. Higher values push the output to follow the prompt more closely,
+ possibly at the expense of image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ Number of images to generate for each prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the DDIM paper. Only applies to schedulers that accept an
+ `eta` argument and is ignored otherwise.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A generator to ensure reproducibility in image generation.
+ latents (`torch.Tensor`, *optional*):
+ Predefined latent tensors to condition generation.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Embeddings for negative prompts. Overrides string inputs if defined.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for the primary prompt embeddings.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for negative prompt embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ Format of the generated output: `"pil"` for PIL images, `"np"` for NumPy arrays, or `"latent"` to
+ return the raw latents.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ If `True`, returns a structured output. Otherwise returns a simple tuple.
+ callback_on_step_end (`Callable`, `PipelineCallback`, or `MultiPipelineCallbacks`, *optional*):
+ Functions called at the end of each denoising step.
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
+ Tensor names to be included in callback function calls.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Rescales the noise prediction to correct for over-exposure introduced by classifier-free guidance
+ (see section 3.4 of https://arxiv.org/pdf/2305.08891.pdf).
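+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process, forwarded to the scheduler through
+ `retrieve_timesteps`.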
+
+ Examples:
+
+ Returns:
+ [`EasyAnimatePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, an [`EasyAnimatePipelineOutput`] is returned; otherwise a `tuple` is
+ returned whose first element is the generated video frames.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 0. default height and width
+ height = int((height // 16) * 16)
+ width = int((width // 16) * 16)
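+ # rounding down to a multiple of 16 keeps height/width divisible through the VAE downsampling and
+ # patch embedding (the rationale is inferred from the snap above, not stated upstream)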
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ callback_on_step_end_tensor_inputs,
+ )
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ else:
+ dtype = self.transformer.dtype
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ text_encoder_index=0,
+ )
+
+ # 4. Prepare timesteps
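+ # flow-matching schedulers take an extra shift parameter (`mu`) when retrieving timesteps;
+ # all other schedulers go through the standard retrieve_timesteps call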
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, mu=1
+ )
+ else:
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ num_frames,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents,
+ )
+
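+ # Control latents come from one of three sources: a camera-pose video resized directly onto the
+ # latent grid, a pixel-space control video encoded with the VAE, or zeros when no control is given.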
+ if control_camera_video is not None:
+ control_video_latents = resize_mask(control_camera_video, latents, process_first_frame_only=True)
+ control_video_latents = control_video_latents * 6
+ control_latents = (
+ torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
+ ).to(device, dtype)
+ elif control_video is not None:
+ batch_size, channels, num_frames, height_video, width_video = control_video.shape
+ control_video = self.image_processor.preprocess(
+ control_video.permute(0, 2, 1, 3, 4).reshape(
+ batch_size * num_frames, channels, height_video, width_video
+ ),
+ height=height,
+ width=width,
+ )
+ control_video = control_video.to(dtype=torch.float32)
+ control_video = control_video.reshape(batch_size, num_frames, channels, height, width).permute(
+ 0, 2, 1, 3, 4
+ )
+ control_video_latents = self.prepare_control_latents(
+ None,
+ control_video,
+ batch_size,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ )[1]
+ control_latents = (
+ torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
+ ).to(device, dtype)
+ else:
+ control_video_latents = torch.zeros_like(latents).to(device, dtype)
+ control_latents = (
+ torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
+ ).to(device, dtype)
+
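+ # When a reference image is given, its VAE latents are written into the first temporal latent slot
+ # (the rest stays zero) and concatenated channel-wise with the control latents; otherwise an all-zero
+ # tensor of the same shape is concatenated so the transformer input layout stays unchanged.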
+ if ref_image is not None:
+ batch_size, channels, num_frames, height_video, width_video = ref_image.shape
+ ref_image = self.image_processor.preprocess(
+ ref_image.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video),
+ height=height,
+ width=width,
+ )
+ ref_image = ref_image.to(dtype=torch.float32)
+ ref_image = ref_image.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
+
+ ref_image_latents = self.prepare_control_latents(
+ None,
+ ref_image,
+ batch_size,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ )[1]
+
+ ref_image_latents_conv_in = torch.zeros_like(latents)
+ if latents.size()[2] != 1:
+ ref_image_latents_conv_in[:, :, :1] = ref_image_latents
+ ref_image_latents_conv_in = (
+ torch.cat([ref_image_latents_conv_in] * 2)
+ if self.do_classifier_free_guidance
+ else ref_image_latents_conv_in
+ ).to(device, dtype)
+ control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1)
+ else:
+ ref_image_latents_conv_in = torch.zeros_like(latents)
+ ref_image_latents_conv_in = (
+ torch.cat([ref_image_latents_conv_in] * 2)
+ if self.do_classifier_free_guidance
+ else ref_image_latents_conv_in
+ ).to(device, dtype)
+ control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1)
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
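+ # negative embeddings are concatenated first so that noise_pred.chunk(2) in the denoising loop
+ # yields (unconditional, text-conditioned) in that order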
+
+ # To latents.device
+ prompt_embeds = prompt_embeds.to(device=device)
+ prompt_attention_mask = prompt_attention_mask.to(device=device)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ if hasattr(self.scheduler, "scale_model_input"):
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
+ t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
+ dtype=latent_model_input.dtype
+ )
+ # predict the noise residual
+ noise_pred = self.transformer(
+ latent_model_input,
+ t_expand,
+ encoder_hidden_states=prompt_embeds,
+ control_latents=control_latents,
+ return_dict=False,
+ )[0]
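+ # the transformer may predict more channels than the VAE latent space (for example a learned
+ # variance term); only the first half of the channels is kept as the noise prediction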
+ if noise_pred.size()[1] != self.vae.config.latent_channels:
+ noise_pred, _ = noise_pred.chunk(2, dim=1)
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # 8. Decode latents and post-process the generated video
+ if output_type != "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return EasyAnimatePipelineOutput(frames=video)
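For orientation, a minimal usage sketch of the control pipeline this file adds, assuming the `EasyAnimateControlPipeline` export in 0.33.0; the checkpoint id and the zero-filled control tensor below are illustrative placeholders, not taken from this diff:

```python
import torch
from diffusers import EasyAnimateControlPipeline
from diffusers.utils import export_to_video

# hypothetical checkpoint id, substitute a real EasyAnimate control checkpoint
pipe = EasyAnimateControlPipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-7b-zh-Control", torch_dtype=torch.bfloat16
).to("cuda")

# placeholder control signal: (batch, channels, frames, height, width), values in [0, 1]
control_video = torch.zeros(1, 3, 49, 512, 512)

video = pipe(
    prompt="a dancer spinning in a sunlit studio",
    control_video=control_video,
    num_frames=49,
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=5.0,
).frames[0]
export_to_video(video, "easyanimate_control.mp4", fps=8)
```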