diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registry.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
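The largest single addition in this release is the EasyAnimate MagViT VAE shown in the hunk below (diffusers/models/autoencoders/autoencoder_kl_magvit.py). The following is a minimal smoke-test sketch written against the constructor defaults and the `encode` / `enable_tiling` / `enable_slicing` methods visible in the diff; it assumes the class is re-exported from the top-level `diffusers` namespace in 0.33.0 and that inputs follow the (batch, channels, frames, height, width) layout used by the encoder. It is an illustration, not part of the package diff itself.

```python
import torch

from diffusers import AutoencoderKLMagvit  # assumed top-level export in 0.33.0

# Constructor defaults come from the __init__ shown in the hunk below.
vae = AutoencoderKLMagvit()
vae.enable_slicing()  # slice across the batch dimension during encode/decode
vae.enable_tiling()   # spatial tiling only kicks in above the 512-pixel minimums

# (batch, channels, frames, height, width); 64x64 stays below the tiling threshold.
video = torch.randn(1, 3, 9, 64, 64)

with torch.no_grad():
    posterior = vae.encode(video).latent_dist
    latents = posterior.sample()

# With four block_out_channels the compression ratios are 8x spatial and 4x temporal,
# so this should print something like torch.Size([1, 16, 3, 8, 8]).
print(latents.shape)
```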
diffusers/models/autoencoders/autoencoder_kl_magvit.py
@@ -0,0 +1,1094 @@
+ # Copyright 2025 The EasyAnimate team and The HuggingFace Team.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from ...configuration_utils import ConfigMixin, register_to_config
+ from ...utils import logging
+ from ...utils.accelerate_utils import apply_forward_hook
+ from ..activations import get_activation
+ from ..modeling_outputs import AutoencoderKLOutput
+ from ..modeling_utils import ModelMixin
+ from .vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ class EasyAnimateCausalConv3d(nn.Conv3d):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: Union[int, Tuple[int, ...]] = 3,
+         stride: Union[int, Tuple[int, ...]] = 1,
+         padding: Union[int, Tuple[int, ...]] = 1,
+         dilation: Union[int, Tuple[int, ...]] = 1,
+         groups: int = 1,
+         bias: bool = True,
+         padding_mode: str = "zeros",
+     ):
+         # Ensure kernel_size, stride, and dilation are tuples of length 3
+         kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
+         assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."
+
+         stride = stride if isinstance(stride, tuple) else (stride,) * 3
+         assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."
+
+         dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
+         assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."
+
+         # Unpack kernel size, stride, and dilation for temporal, height, and width dimensions
+         t_ks, h_ks, w_ks = kernel_size
+         self.t_stride, h_stride, w_stride = stride
+         t_dilation, h_dilation, w_dilation = dilation
+
+         # Calculate padding for temporal dimension to maintain causality
+         t_pad = (t_ks - 1) * t_dilation
+
+         # Calculate padding for height and width dimensions based on the padding parameter
+         if padding is None:
+             h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
+             w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
+         elif isinstance(padding, int):
+             h_pad = w_pad = padding
+         else:
+             assert NotImplementedError
+
+         # Store temporal padding and initialize flags and previous features cache
+         self.temporal_padding = t_pad
+         self.temporal_padding_origin = math.ceil(((t_ks - 1) * w_dilation + (1 - w_stride)) / 2)
+
+         self.prev_features = None
+
+         # Initialize the parent class with modified padding
+         super().__init__(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             dilation=dilation,
+             padding=(0, h_pad, w_pad),
+             groups=groups,
+             bias=bias,
+             padding_mode=padding_mode,
+         )
+
+     def _clear_conv_cache(self):
+         del self.prev_features
+         self.prev_features = None
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         # Ensure input tensor is of the correct type
+         dtype = hidden_states.dtype
+         if self.prev_features is None:
+             # Pad the input tensor in the temporal dimension to maintain causality
+             hidden_states = F.pad(
+                 hidden_states,
+                 pad=(0, 0, 0, 0, self.temporal_padding, 0),
+                 mode="replicate",  # TODO: check if this is necessary
+             )
+             hidden_states = hidden_states.to(dtype=dtype)
+
+             # Clear cache before processing and store previous features for causality
+             self._clear_conv_cache()
+             self.prev_features = hidden_states[:, :, -self.temporal_padding :].clone()
+
+             # Process the input tensor in chunks along the temporal dimension
+             num_frames = hidden_states.size(2)
+             outputs = []
+             i = 0
+             while i + self.temporal_padding + 1 <= num_frames:
+                 out = super().forward(hidden_states[:, :, i : i + self.temporal_padding + 1])
+                 i += self.t_stride
+                 outputs.append(out)
+             return torch.concat(outputs, 2)
+         else:
+             # Concatenate previous features with the input tensor for continuous temporal processing
+             if self.t_stride == 2:
+                 hidden_states = torch.concat(
+                     [self.prev_features[:, :, -(self.temporal_padding - 1) :], hidden_states], dim=2
+                 )
+             else:
+                 hidden_states = torch.concat([self.prev_features, hidden_states], dim=2)
+             hidden_states = hidden_states.to(dtype=dtype)
+
+             # Clear cache and update previous features
+             self._clear_conv_cache()
+             self.prev_features = hidden_states[:, :, -self.temporal_padding :].clone()
+
+             # Process the concatenated tensor in chunks along the temporal dimension
+             num_frames = hidden_states.size(2)
+             outputs = []
+             i = 0
+             while i + self.temporal_padding + 1 <= num_frames:
+                 out = super().forward(hidden_states[:, :, i : i + self.temporal_padding + 1])
+                 i += self.t_stride
+                 outputs.append(out)
+             return torch.concat(outputs, 2)
+
+
+ class EasyAnimateResidualBlock3D(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         non_linearity: str = "silu",
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         spatial_group_norm: bool = True,
+         dropout: float = 0.0,
+         output_scale_factor: float = 1.0,
+     ):
+         super().__init__()
+
+         self.output_scale_factor = output_scale_factor
+
+         # Group normalization for input tensor
+         self.norm1 = nn.GroupNorm(
+             num_groups=norm_num_groups,
+             num_channels=in_channels,
+             eps=norm_eps,
+             affine=True,
+         )
+         self.nonlinearity = get_activation(non_linearity)
+         self.conv1 = EasyAnimateCausalConv3d(in_channels, out_channels, kernel_size=3)
+
+         self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True)
+         self.dropout = nn.Dropout(dropout)
+         self.conv2 = EasyAnimateCausalConv3d(out_channels, out_channels, kernel_size=3)
+
+         if in_channels != out_channels:
+             self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1)
+         else:
+             self.shortcut = nn.Identity()
+
+         self.spatial_group_norm = spatial_group_norm
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         shortcut = self.shortcut(hidden_states)
+
+         if self.spatial_group_norm:
+             batch_size = hidden_states.size(0)
+             hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)  # [B, C, T, H, W] -> [B * T, C, H, W]
+             hidden_states = self.norm1(hidden_states)
+             hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
+                 0, 2, 1, 3, 4
+             )  # [B * T, C, H, W] -> [B, C, T, H, W]
+         else:
+             hidden_states = self.norm1(hidden_states)
+
+         hidden_states = self.nonlinearity(hidden_states)
+         hidden_states = self.conv1(hidden_states)
+
+         if self.spatial_group_norm:
+             batch_size = hidden_states.size(0)
+             hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)  # [B, C, T, H, W] -> [B * T, C, H, W]
+             hidden_states = self.norm2(hidden_states)
+             hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
+                 0, 2, 1, 3, 4
+             )  # [B * T, C, H, W] -> [B, C, T, H, W]
+         else:
+             hidden_states = self.norm2(hidden_states)
+
+         hidden_states = self.nonlinearity(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = self.conv2(hidden_states)
+
+         return (hidden_states + shortcut) / self.output_scale_factor
+
+
+ class EasyAnimateDownsampler3D(nn.Module):
+     def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: tuple = (2, 2, 2)):
+         super().__init__()
+
+         self.conv = EasyAnimateCausalConv3d(
+             in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=0
+         )
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = F.pad(hidden_states, (0, 1, 0, 1))
+         hidden_states = self.conv(hidden_states)
+         return hidden_states
+
+
+ class EasyAnimateUpsampler3D(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int = 3,
+         temporal_upsample: bool = False,
+         spatial_group_norm: bool = True,
+     ):
+         super().__init__()
+         out_channels = out_channels or in_channels
+
+         self.temporal_upsample = temporal_upsample
+         self.spatial_group_norm = spatial_group_norm
+
+         self.conv = EasyAnimateCausalConv3d(
+             in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size
+         )
+         self.prev_features = None
+
+     def _clear_conv_cache(self):
+         del self.prev_features
+         self.prev_features = None
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = F.interpolate(hidden_states, scale_factor=(1, 2, 2), mode="nearest")
+         hidden_states = self.conv(hidden_states)
+
+         if self.temporal_upsample:
+             if self.prev_features is None:
+                 self.prev_features = hidden_states
+             else:
+                 hidden_states = F.interpolate(
+                     hidden_states,
+                     scale_factor=(2, 1, 1),
+                     mode="trilinear" if not self.spatial_group_norm else "nearest",
+                 )
+         return hidden_states
+
+
+ class EasyAnimateDownBlock3D(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         num_layers: int = 1,
+         act_fn: str = "silu",
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         spatial_group_norm: bool = True,
+         dropout: float = 0.0,
+         output_scale_factor: float = 1.0,
+         add_downsample: bool = True,
+         add_temporal_downsample: bool = True,
+     ):
+         super().__init__()
+
+         self.convs = nn.ModuleList([])
+         for i in range(num_layers):
+             in_channels = in_channels if i == 0 else out_channels
+             self.convs.append(
+                 EasyAnimateResidualBlock3D(
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     non_linearity=act_fn,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     spatial_group_norm=spatial_group_norm,
+                     dropout=dropout,
+                     output_scale_factor=output_scale_factor,
+                 )
+             )
+
+         if add_downsample and add_temporal_downsample:
+             self.downsampler = EasyAnimateDownsampler3D(out_channels, out_channels, kernel_size=3, stride=(2, 2, 2))
+             self.spatial_downsample_factor = 2
+             self.temporal_downsample_factor = 2
+         elif add_downsample and not add_temporal_downsample:
+             self.downsampler = EasyAnimateDownsampler3D(out_channels, out_channels, kernel_size=3, stride=(1, 2, 2))
+             self.spatial_downsample_factor = 2
+             self.temporal_downsample_factor = 1
+         else:
+             self.downsampler = None
+             self.spatial_downsample_factor = 1
+             self.temporal_downsample_factor = 1
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         for conv in self.convs:
+             hidden_states = conv(hidden_states)
+         if self.downsampler is not None:
+             hidden_states = self.downsampler(hidden_states)
+         return hidden_states
+
+
+ class EasyAnimateUpBlock3d(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         num_layers: int = 1,
+         act_fn: str = "silu",
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         spatial_group_norm: bool = False,
+         dropout: float = 0.0,
+         output_scale_factor: float = 1.0,
+         add_upsample: bool = True,
+         add_temporal_upsample: bool = True,
+     ):
+         super().__init__()
+
+         self.convs = nn.ModuleList([])
+         for i in range(num_layers):
+             in_channels = in_channels if i == 0 else out_channels
+             self.convs.append(
+                 EasyAnimateResidualBlock3D(
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     non_linearity=act_fn,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     spatial_group_norm=spatial_group_norm,
+                     dropout=dropout,
+                     output_scale_factor=output_scale_factor,
+                 )
+             )
+
+         if add_upsample:
+             self.upsampler = EasyAnimateUpsampler3D(
+                 in_channels,
+                 in_channels,
+                 temporal_upsample=add_temporal_upsample,
+                 spatial_group_norm=spatial_group_norm,
+             )
+         else:
+             self.upsampler = None
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         for conv in self.convs:
+             hidden_states = conv(hidden_states)
+         if self.upsampler is not None:
+             hidden_states = self.upsampler(hidden_states)
+         return hidden_states
+
+
+ class EasyAnimateMidBlock3d(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         num_layers: int = 1,
+         act_fn: str = "silu",
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         spatial_group_norm: bool = True,
+         dropout: float = 0.0,
+         output_scale_factor: float = 1.0,
+     ):
+         super().__init__()
+
+         norm_num_groups = norm_num_groups if norm_num_groups is not None else min(in_channels // 4, 32)
+
+         self.convs = nn.ModuleList(
+             [
+                 EasyAnimateResidualBlock3D(
+                     in_channels=in_channels,
+                     out_channels=in_channels,
+                     non_linearity=act_fn,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     spatial_group_norm=spatial_group_norm,
+                     dropout=dropout,
+                     output_scale_factor=output_scale_factor,
+                 )
+             ]
+         )
+
+         for _ in range(num_layers - 1):
+             self.convs.append(
+                 EasyAnimateResidualBlock3D(
+                     in_channels=in_channels,
+                     out_channels=in_channels,
+                     non_linearity=act_fn,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     spatial_group_norm=spatial_group_norm,
+                     dropout=dropout,
+                     output_scale_factor=output_scale_factor,
+                 )
+             )
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.convs[0](hidden_states)
+         for resnet in self.convs[1:]:
+             hidden_states = resnet(hidden_states)
+         return hidden_states
+
+
+ class EasyAnimateEncoder(nn.Module):
+     r"""
+     Causal encoder for 3D video-like data used in [EasyAnimate](https://arxiv.org/abs/2405.18991).
+     """
+
+     _supports_gradient_checkpointing = True
+
+     def __init__(
+         self,
+         in_channels: int = 3,
+         out_channels: int = 8,
+         down_block_types: Tuple[str, ...] = (
+             "SpatialDownBlock3D",
+             "SpatialTemporalDownBlock3D",
+             "SpatialTemporalDownBlock3D",
+             "SpatialTemporalDownBlock3D",
+         ),
+         block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
+         layers_per_block: int = 2,
+         norm_num_groups: int = 32,
+         act_fn: str = "silu",
+         double_z: bool = True,
+         spatial_group_norm: bool = False,
+     ):
+         super().__init__()
+
+         # 1. Input convolution
+         self.conv_in = EasyAnimateCausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
+
+         # 2. Down blocks
+         self.down_blocks = nn.ModuleList([])
+         output_channels = block_out_channels[0]
+         for i, down_block_type in enumerate(down_block_types):
+             input_channels = output_channels
+             output_channels = block_out_channels[i]
+             is_final_block = i == len(block_out_channels) - 1
+             if down_block_type == "SpatialDownBlock3D":
+                 down_block = EasyAnimateDownBlock3D(
+                     in_channels=input_channels,
+                     out_channels=output_channels,
+                     num_layers=layers_per_block,
+                     act_fn=act_fn,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=1e-6,
+                     spatial_group_norm=spatial_group_norm,
+                     add_downsample=not is_final_block,
+                     add_temporal_downsample=False,
+                 )
+             elif down_block_type == "SpatialTemporalDownBlock3D":
+                 down_block = EasyAnimateDownBlock3D(
+                     in_channels=input_channels,
+                     out_channels=output_channels,
+                     num_layers=layers_per_block,
+                     act_fn=act_fn,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=1e-6,
+                     spatial_group_norm=spatial_group_norm,
+                     add_downsample=not is_final_block,
+                     add_temporal_downsample=True,
+                 )
+             else:
+                 raise ValueError(f"Unknown up block type: {down_block_type}")
+             self.down_blocks.append(down_block)
+
+         # 3. Middle block
+         self.mid_block = EasyAnimateMidBlock3d(
+             in_channels=block_out_channels[-1],
+             num_layers=layers_per_block,
+             act_fn=act_fn,
+             spatial_group_norm=spatial_group_norm,
+             norm_num_groups=norm_num_groups,
+             norm_eps=1e-6,
+             dropout=0,
+             output_scale_factor=1,
+         )
+
+         # 4. Output normalization & convolution
+         self.spatial_group_norm = spatial_group_norm
+         self.conv_norm_out = nn.GroupNorm(
+             num_channels=block_out_channels[-1],
+             num_groups=norm_num_groups,
+             eps=1e-6,
+         )
+         self.conv_act = get_activation(act_fn)
+
+         # Initialize the output convolution layer
+         conv_out_channels = 2 * out_channels if double_z else out_channels
+         self.conv_out = EasyAnimateCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
+
+         self.gradient_checkpointing = False
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         # hidden_states: (B, C, T, H, W)
+         hidden_states = self.conv_in(hidden_states)
+
+         for down_block in self.down_blocks:
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
+                 hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
+             else:
+                 hidden_states = down_block(hidden_states)
+
+         hidden_states = self.mid_block(hidden_states)
+
+         if self.spatial_group_norm:
+             batch_size = hidden_states.size(0)
+             hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+             hidden_states = self.conv_norm_out(hidden_states)
+             hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+         else:
+             hidden_states = self.conv_norm_out(hidden_states)
+
+         hidden_states = self.conv_act(hidden_states)
+         hidden_states = self.conv_out(hidden_states)
+         return hidden_states
+
+
+ class EasyAnimateDecoder(nn.Module):
+     r"""
+     Causal decoder for 3D video-like data used in [EasyAnimate](https://arxiv.org/abs/2405.18991).
+     """
+
+     _supports_gradient_checkpointing = True
+
+     def __init__(
+         self,
+         in_channels: int = 8,
+         out_channels: int = 3,
+         up_block_types: Tuple[str, ...] = (
+             "SpatialUpBlock3D",
+             "SpatialTemporalUpBlock3D",
+             "SpatialTemporalUpBlock3D",
+             "SpatialTemporalUpBlock3D",
+         ),
+         block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
+         layers_per_block: int = 2,
+         norm_num_groups: int = 32,
+         act_fn: str = "silu",
+         spatial_group_norm: bool = False,
+     ):
+         super().__init__()
+
+         # 1. Input convolution
+         self.conv_in = EasyAnimateCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3)
+
+         # 2. Middle block
+         self.mid_block = EasyAnimateMidBlock3d(
+             in_channels=block_out_channels[-1],
+             num_layers=layers_per_block,
+             act_fn=act_fn,
+             norm_num_groups=norm_num_groups,
+             norm_eps=1e-6,
+             dropout=0,
+             output_scale_factor=1,
+         )
+
+         # 3. Up blocks
+         self.up_blocks = nn.ModuleList([])
+         reversed_block_out_channels = list(reversed(block_out_channels))
+         output_channels = reversed_block_out_channels[0]
+         for i, up_block_type in enumerate(up_block_types):
+             input_channels = output_channels
+             output_channels = reversed_block_out_channels[i]
+             is_final_block = i == len(block_out_channels) - 1
+
+             # Create and append up block to up_blocks
+             if up_block_type == "SpatialUpBlock3D":
+                 up_block = EasyAnimateUpBlock3d(
+                     in_channels=input_channels,
+                     out_channels=output_channels,
+                     num_layers=layers_per_block + 1,
+                     act_fn=act_fn,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=1e-6,
+                     spatial_group_norm=spatial_group_norm,
+                     add_upsample=not is_final_block,
+                     add_temporal_upsample=False,
+                 )
+             elif up_block_type == "SpatialTemporalUpBlock3D":
+                 up_block = EasyAnimateUpBlock3d(
+                     in_channels=input_channels,
+                     out_channels=output_channels,
+                     num_layers=layers_per_block + 1,
+                     act_fn=act_fn,
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=1e-6,
+                     spatial_group_norm=spatial_group_norm,
+                     add_upsample=not is_final_block,
+                     add_temporal_upsample=True,
+                 )
+             else:
+                 raise ValueError(f"Unknown up block type: {up_block_type}")
+             self.up_blocks.append(up_block)
+
+         # Output normalization and activation
+         self.spatial_group_norm = spatial_group_norm
+         self.conv_norm_out = nn.GroupNorm(
+             num_channels=block_out_channels[0],
+             num_groups=norm_num_groups,
+             eps=1e-6,
+         )
+         self.conv_act = get_activation(act_fn)
+
+         # Output convolution layer
+         self.conv_out = EasyAnimateCausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
+
+         self.gradient_checkpointing = False
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         # hidden_states: (B, C, T, H, W)
+         hidden_states = self.conv_in(hidden_states)
+
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+             hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
+         else:
+             hidden_states = self.mid_block(hidden_states)
+
+         for up_block in self.up_blocks:
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
+                 hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
+             else:
+                 hidden_states = up_block(hidden_states)
+
+         if self.spatial_group_norm:
+             batch_size = hidden_states.size(0)
+             hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)  # [B, C, T, H, W] -> [B * T, C, H, W]
+             hidden_states = self.conv_norm_out(hidden_states)
+             hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
+                 0, 2, 1, 3, 4
+             )  # [B * T, C, H, W] -> [B, C, T, H, W]
+         else:
+             hidden_states = self.conv_norm_out(hidden_states)
+
+         hidden_states = self.conv_act(hidden_states)
+         hidden_states = self.conv_out(hidden_states)
+         return hidden_states
+
+
+ class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
+     r"""
+     A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This
+     model is used in [EasyAnimate](https://arxiv.org/abs/2405.18991).
+
+     This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+     for all models (such as downloading or saving).
+     """
+
+     _supports_gradient_checkpointing = True
+
+     @register_to_config
+     def __init__(
+         self,
+         in_channels: int = 3,
+         latent_channels: int = 16,
+         out_channels: int = 3,
+         block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
+         down_block_types: Tuple[str, ...] = [
+             "SpatialDownBlock3D",
+             "SpatialTemporalDownBlock3D",
+             "SpatialTemporalDownBlock3D",
+             "SpatialTemporalDownBlock3D",
+         ],
+         up_block_types: Tuple[str, ...] = [
+             "SpatialUpBlock3D",
+             "SpatialTemporalUpBlock3D",
+             "SpatialTemporalUpBlock3D",
+             "SpatialTemporalUpBlock3D",
+         ],
+         layers_per_block: int = 2,
+         act_fn: str = "silu",
+         norm_num_groups: int = 32,
+         scaling_factor: float = 0.7125,
+         spatial_group_norm: bool = True,
+     ):
+         super().__init__()
+
+         # Initialize the encoder
+         self.encoder = EasyAnimateEncoder(
+             in_channels=in_channels,
+             out_channels=latent_channels,
+             down_block_types=down_block_types,
+             block_out_channels=block_out_channels,
+             layers_per_block=layers_per_block,
+             norm_num_groups=norm_num_groups,
+             act_fn=act_fn,
+             double_z=True,
+             spatial_group_norm=spatial_group_norm,
+         )
+
+         # Initialize the decoder
+         self.decoder = EasyAnimateDecoder(
+             in_channels=latent_channels,
+             out_channels=out_channels,
+             up_block_types=up_block_types,
+             block_out_channels=block_out_channels,
+             layers_per_block=layers_per_block,
+             norm_num_groups=norm_num_groups,
+             act_fn=act_fn,
+             spatial_group_norm=spatial_group_norm,
+         )
+
+         # Initialize convolution layers for quantization and post-quantization
+         self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
+         self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
+
+         self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
+         self.temporal_compression_ratio = 2 ** (len(block_out_channels) - 2)
+
+         # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+         # to perform decoding of a single video latent at a time.
+         self.use_slicing = False
+
+         # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+         # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+         # intermediate tiles together, the memory requirement can be lowered.
+         self.use_tiling = False
+
+         # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
+         # at a fixed frame batch size (based on `self.num_latent_frames_batch_size`), the memory requirement can be lowered.
+         self.use_framewise_encoding = False
+         self.use_framewise_decoding = False
+
+         # Assign mini-batch sizes for encoder and decoder
+         self.num_sample_frames_batch_size = 4
+         self.num_latent_frames_batch_size = 1
+
+         # The minimal tile height and width for spatial tiling to be used
+         self.tile_sample_min_height = 512
+         self.tile_sample_min_width = 512
+         self.tile_sample_min_num_frames = 4
+
+         # The minimal distance between two spatial tiles
+         self.tile_sample_stride_height = 448
+         self.tile_sample_stride_width = 448
+         self.tile_sample_stride_num_frames = 8
+
+     def _clear_conv_cache(self):
+         # Clear cache for convolutional layers if needed
+         for name, module in self.named_modules():
+             if isinstance(module, EasyAnimateCausalConv3d):
+                 module._clear_conv_cache()
+             if isinstance(module, EasyAnimateUpsampler3D):
+                 module._clear_conv_cache()
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_min_num_frames: Optional[int] = None,
+        tile_sample_stride_height: Optional[int] = None,
+        tile_sample_stride_width: Optional[int] = None,
+        tile_sample_stride_num_frames: Optional[int] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
+        processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_min_num_frames (`int`, *optional*):
+                The minimum number of frames required for a sample to be separated into tiles across the frame
+                dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
+                artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
+            tile_sample_stride_num_frames (`int`, *optional*):
+                The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts
+                produced across the frame dimension.
+        """
+        self.use_tiling = True
+        self.use_framewise_decoding = True
+        self.use_framewise_encoding = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
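These toggles are typically flipped on a loaded model before running a memory-heavy video pipeline. A minimal usage sketch, assuming this class is exported from diffusers as AutoencoderKLMagvit (matching the new autoencoder_kl_magvit.py module in this release) and using a placeholder checkpoint path:

import torch
from diffusers import AutoencoderKLMagvit  # exported name is an assumption

# "path/to/easyanimate-vae" is a placeholder, not a real repository id.
vae = AutoencoderKLMagvit.from_pretrained("path/to/easyanimate-vae", torch_dtype=torch.float16)
vae.enable_slicing()  # process one batch element at a time
vae.enable_tiling(tile_sample_min_height=384, tile_sample_min_width=384)  # smaller tiles, lower peak memory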
+    @apply_forward_hook
+    def _encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
+        """
+        Encode a batch of videos into latent moments.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Forwarded to `tiled_encode` when tiling is used.
+
+        Returns:
+            `torch.Tensor`:
+                The latent moments of the encoded videos (the concatenated mean and log-variance that `encode` wraps
+                in a posterior distribution).
+        """
+        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_width or x.shape[-2] > self.tile_sample_min_height):
+            return self.tiled_encode(x, return_dict=return_dict)
+
+        first_frames = self.encoder(x[:, :, :1, :, :])
+        h = [first_frames]
+        for i in range(1, x.shape[2], self.num_sample_frames_batch_size):
+            next_frames = self.encoder(x[:, :, i : i + self.num_sample_frames_batch_size, :, :])
+            h.append(next_frames)
+        h = torch.cat(h, dim=2)
+        moments = self.quant_conv(h)
+
+        self._clear_conv_cache()
+        return moments
+
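The encoding loop pushes the first frame through the encoder on its own, presumably because the causal temporal convolutions treat the first frame specially, and then the remaining frames in groups of num_sample_frames_batch_size. A standalone sketch of how that chunking falls out for a 17-frame clip with the default batch size of 4:

# Sketch: temporal chunking used by _encode; values mirror the defaults set in the initializer.
num_frames = 17
batch = 4  # num_sample_frames_batch_size
chunks = [(0, 1)] + [(i, min(i + batch, num_frames)) for i in range(1, num_frames, batch)]
print(chunks)  # [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]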
+    @apply_forward_hook
+    def encode(
+        self, x: torch.Tensor, return_dict: bool = True
+    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+        """
+        Encode a batch of videos into latents.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+        Returns:
+            The latent representations of the encoded videos. If `return_dict` is True, a
+            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+        """
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+
+        posterior = DiagonalGaussianDistribution(h)
+
+        if not return_dict:
+            return (posterior,)
+        return AutoencoderKLOutput(latent_dist=posterior)
+
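In practice `encode` takes a (batch, channels, frames, height, width) tensor and the returned posterior is sampled, or its mode taken, to obtain latents. A minimal sketch reusing the `vae` from the earlier example; the shapes are illustrative:

# Sketch: encoding a short clip into latents. `vae` comes from the sketch above.
video = torch.randn(1, 3, 9, 512, 512, dtype=torch.float16)  # (B, C, T, H, W); real inputs are scaled to [-1, 1]
with torch.no_grad():
    posterior = vae.encode(video.to(vae.device)).latent_dist
latents = posterior.sample()  # or posterior.mode() for a deterministic encoding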
+    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        batch_size, num_channels, num_frames, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (z.shape[-1] > tile_latent_min_width or z.shape[-2] > tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=return_dict)
+
+        z = self.post_quant_conv(z)
+
+        # Process the first frame and save the result
+        first_frames = self.decoder(z[:, :, :1, :, :])
+        # Initialize the list to store the processed frames, starting with the first frame
+        dec = [first_frames]
+        # Process the remaining frames in groups of `num_latent_frames_batch_size`
+        for i in range(1, z.shape[2], self.num_latent_frames_batch_size):
+            next_frames = self.decoder(z[:, :, i : i + self.num_latent_frames_batch_size, :, :])
+            dec.append(next_frames)
+        # Concatenate all processed frames along the temporal (frame) dimension
+        dec = torch.cat(dec, dim=2)
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    @apply_forward_hook
+    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        """
+        Decode a batch of video latents.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
+        self._clear_conv_cache()
+        if not return_dict:
+            return (decoded,)
+        return DecoderOutput(sample=decoded)
+
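decode accepts latents shaped like the encoder output and returns sample-space frames; with return_dict=False it returns a one-element tuple instead of a DecoderOutput. A minimal sketch continuing from the encoding example above (the exact output frame count depends on the model's temporal compression, which is not stated in this diff):

# Sketch: decoding the latents produced in the previous sketch.
with torch.no_grad():
    (frames,) = vae.decode(latents, return_dict=False)
print(frames.shape)  # (1, 3, T', 512, 512); T' depends on the temporal compression ratio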
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
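blend_v and blend_h apply a linear cross-fade over the rows or columns where two tiles overlap, so the seam weight ramps from the first tile to the second. A tiny standalone illustration of the weighting, independent of this class:

import torch

# Sketch: linear cross-fade over 4 overlapping rows, as blend_v does.
blend_extent = 4
a = torch.ones(1, 1, 1, 8, 8)   # stands in for the bottom edge of the upper tile
b = torch.zeros(1, 1, 1, 8, 8)  # stands in for the top edge of the lower tile
for y in range(blend_extent):
    b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
print(b[0, 0, 0, :4, 0])  # tensor([1.0000, 0.7500, 0.5000, 0.2500])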
+    def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split the video into overlapping spatial tiles (512x512 by default) and encode them separately.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                tile = x[
+                    :,
+                    :,
+                    :,
+                    i : i + self.tile_sample_min_height,
+                    j : j + self.tile_sample_min_width,
+                ]
+
+                first_frames = self.encoder(tile[:, :, 0:1, :, :])
+                tile_h = [first_frames]
+                for k in range(1, num_frames, self.num_sample_frames_batch_size):
+                    next_frames = self.encoder(tile[:, :, k : k + self.num_sample_frames_batch_size, :, :])
+                    tile_h.append(next_frames)
+                tile = torch.cat(tile_h, dim=2)
+                tile = self.quant_conv(tile)
+                self._clear_conv_cache()
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=4))
+
+        moments = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return moments
+
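The nested loops above walk the tile origins in strides of tile_sample_stride_height/width, so each tile overlaps its right and bottom neighbours by the min-minus-stride margin that is later blended away. A standalone sketch of which origins are visited for a 1024x1024 input under the default 512/448 settings:

# Sketch: tile origins visited by tiled_encode for a 1024x1024 frame.
height = width = 1024
stride = 448
origins = [(i, j) for i in range(0, height, stride) for j in range(0, width, stride)]
print(len(origins), origins[:4])  # 9 [(0, 0), (0, 448), (0, 896), (448, 0)]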
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        batch_size, num_channels, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping latent tiles (64x64 by default) and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                tile = z[
+                    :,
+                    :,
+                    :,
+                    i : i + tile_latent_min_height,
+                    j : j + tile_latent_min_width,
+                ]
+                tile = self.post_quant_conv(tile)
+
+                # Process the first frame and save the result
+                first_frames = self.decoder(tile[:, :, :1, :, :])
+                # Initialize the list to store the processed frames, starting with the first frame
+                tile_dec = [first_frames]
+                # Process the remaining frames in groups of `num_latent_frames_batch_size`
+                for k in range(1, num_frames, self.num_latent_frames_batch_size):
+                    next_frames = self.decoder(tile[:, :, k : k + self.num_latent_frames_batch_size, :, :])
+                    tile_dec.append(next_frames)
+                # Concatenate all processed frames along the temporal (frame) dimension
+                decoded = torch.cat(tile_dec, dim=2)
+                self._clear_conv_cache()
+                row.append(decoded)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=4))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        sample_posterior: bool = False,
+        return_dict: bool = True,
+        generator: Optional[torch.Generator] = None,
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Args:
+            sample (`torch.Tensor`): Input sample.
+            sample_posterior (`bool`, *optional*, defaults to `False`):
+                Whether to sample from the posterior.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+            generator (`torch.Generator`, *optional*):
+                PyTorch random number generator used when sampling from the posterior.
+        """
+        x = sample
+        posterior = self.encode(x).latent_dist
+        if sample_posterior:
+            z = posterior.sample(generator=generator)
+        else:
+            z = posterior.mode()
+        dec = self.decode(z).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
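forward is an encode/decode round trip, mainly useful for training objectives or quick reconstruction checks. A minimal sketch reusing the vae and video tensors from the sketches above; the output frame count is not guaranteed to match the input exactly, since it depends on the model's temporal compression:

# Sketch: full reconstruction pass through forward().
with torch.no_grad():
    recon = vae(video.to(vae.device), sample_posterior=False).sample
print(video.shape, recon.shape)  # spatial sizes match; frame counts depend on temporal compression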