diffusers-0.27.0-py3-none-any.whl → diffusers-0.32.2-py3-none-any.whl

Files changed (445)
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +50 -53
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.0.dist-info/RECORD +0 -399
  443. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
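
Among the new model files added in this release is the Allegro video VAE; the diff of the new module diffusers/models/autoencoders/autoencoder_kl_allegro.py follows. As a quick orientation before the listing, here is a minimal usage sketch. It is hedged: the top-level AutoencoderKLAllegro export and the Hub repository id "rhymes-ai/Allegro" are assumptions not shown in this diff; only the class itself and its enable_tiling()/encode() methods appear in the file below.

    import torch
    from diffusers import AutoencoderKLAllegro  # assumed top-level export in 0.32.x

    # Assumed checkpoint layout: the Allegro weights in diffusers format keep the VAE in a "vae" subfolder.
    vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32)
    vae.enable_tiling()  # required: the non-tiled encode path raises NotImplementedError (see _encode below)

    # Videos are batched as [batch, channels, frames, height, width].
    video = torch.randn(1, 3, 24, 320, 320)
    with torch.no_grad():
        posterior = vae.encode(video).latent_dist
        latents = posterior.sample() * vae.config.scaling_factor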
diffusers/models/autoencoders/autoencoder_kl_allegro.py (new file)
@@ -0,0 +1,1149 @@
1
+ # Copyright 2024 The RhymesAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...utils.accelerate_utils import apply_forward_hook
24
+ from ..attention_processor import Attention, SpatialNorm
25
+ from ..autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
26
+ from ..downsampling import Downsample2D
27
+ from ..modeling_outputs import AutoencoderKLOutput
28
+ from ..modeling_utils import ModelMixin
29
+ from ..resnet import ResnetBlock2D
30
+ from ..upsampling import Upsample2D
31
+
32
+
33
+ class AllegroTemporalConvLayer(nn.Module):
34
+ r"""
35
+ Temporal convolutional layer that can be used for video (sequence of images) input. Code adapted from:
36
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ in_dim: int,
42
+ out_dim: Optional[int] = None,
43
+ dropout: float = 0.0,
44
+ norm_num_groups: int = 32,
45
+ up_sample: bool = False,
46
+ down_sample: bool = False,
47
+ stride: int = 1,
48
+ ) -> None:
49
+ super().__init__()
50
+
51
+ out_dim = out_dim or in_dim
52
+ pad_h = pad_w = int((stride - 1) * 0.5)
53
+ pad_t = 0
54
+
55
+ self.down_sample = down_sample
56
+ self.up_sample = up_sample
57
+
58
+ if down_sample:
59
+ self.conv1 = nn.Sequential(
60
+ nn.GroupNorm(norm_num_groups, in_dim),
61
+ nn.SiLU(),
62
+ nn.Conv3d(in_dim, out_dim, (2, stride, stride), stride=(2, 1, 1), padding=(0, pad_h, pad_w)),
63
+ )
64
+ elif up_sample:
65
+ self.conv1 = nn.Sequential(
66
+ nn.GroupNorm(norm_num_groups, in_dim),
67
+ nn.SiLU(),
68
+ nn.Conv3d(in_dim, out_dim * 2, (1, stride, stride), padding=(0, pad_h, pad_w)),
69
+ )
70
+ else:
71
+ self.conv1 = nn.Sequential(
72
+ nn.GroupNorm(norm_num_groups, in_dim),
73
+ nn.SiLU(),
74
+ nn.Conv3d(in_dim, out_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
75
+ )
76
+ self.conv2 = nn.Sequential(
77
+ nn.GroupNorm(norm_num_groups, out_dim),
78
+ nn.SiLU(),
79
+ nn.Dropout(dropout),
80
+ nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
81
+ )
82
+ self.conv3 = nn.Sequential(
83
+ nn.GroupNorm(norm_num_groups, out_dim),
84
+ nn.SiLU(),
85
+ nn.Dropout(dropout),
86
+ nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)),
87
+ )
88
+ self.conv4 = nn.Sequential(
89
+ nn.GroupNorm(norm_num_groups, out_dim),
90
+ nn.SiLU(),
91
+ nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)),
92
+ )
93
+
94
+ @staticmethod
95
+ def _pad_temporal_dim(hidden_states: torch.Tensor) -> torch.Tensor:
96
+ hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2)
97
+ hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2)
98
+ return hidden_states
99
+
100
+ def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor:
101
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
102
+
103
+ if self.down_sample:
104
+ identity = hidden_states[:, :, ::2]
105
+ elif self.up_sample:
106
+ identity = hidden_states.repeat_interleave(2, dim=2)
107
+ else:
108
+ identity = hidden_states
109
+
110
+ if self.down_sample or self.up_sample:
111
+ hidden_states = self.conv1(hidden_states)
112
+ else:
113
+ hidden_states = self._pad_temporal_dim(hidden_states)
114
+ hidden_states = self.conv1(hidden_states)
115
+
116
+ if self.up_sample:
117
+ hidden_states = hidden_states.unflatten(1, (2, -1)).permute(0, 2, 3, 1, 4, 5).flatten(2, 3)
118
+
119
+ hidden_states = self._pad_temporal_dim(hidden_states)
120
+ hidden_states = self.conv2(hidden_states)
121
+
122
+ hidden_states = self._pad_temporal_dim(hidden_states)
123
+ hidden_states = self.conv3(hidden_states)
124
+
125
+ hidden_states = self._pad_temporal_dim(hidden_states)
126
+ hidden_states = self.conv4(hidden_states)
127
+
128
+ hidden_states = identity + hidden_states
129
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
130
+
131
+ return hidden_states
132
+
133
+
134
+ class AllegroDownBlock3D(nn.Module):
135
+ def __init__(
136
+ self,
137
+ in_channels: int,
138
+ out_channels: int,
139
+ dropout: float = 0.0,
140
+ num_layers: int = 1,
141
+ resnet_eps: float = 1e-6,
142
+ resnet_time_scale_shift: str = "default",
143
+ resnet_act_fn: str = "swish",
144
+ resnet_groups: int = 32,
145
+ resnet_pre_norm: bool = True,
146
+ output_scale_factor: float = 1.0,
147
+ spatial_downsample: bool = True,
148
+ temporal_downsample: bool = False,
149
+ downsample_padding: int = 1,
150
+ ):
151
+ super().__init__()
152
+
153
+ resnets = []
154
+ temp_convs = []
155
+
156
+ for i in range(num_layers):
157
+ in_channels = in_channels if i == 0 else out_channels
158
+ resnets.append(
159
+ ResnetBlock2D(
160
+ in_channels=in_channels,
161
+ out_channels=out_channels,
162
+ temb_channels=None,
163
+ eps=resnet_eps,
164
+ groups=resnet_groups,
165
+ dropout=dropout,
166
+ time_embedding_norm=resnet_time_scale_shift,
167
+ non_linearity=resnet_act_fn,
168
+ output_scale_factor=output_scale_factor,
169
+ pre_norm=resnet_pre_norm,
170
+ )
171
+ )
172
+ temp_convs.append(
173
+ AllegroTemporalConvLayer(
174
+ out_channels,
175
+ out_channels,
176
+ dropout=0.1,
177
+ norm_num_groups=resnet_groups,
178
+ )
179
+ )
180
+
181
+ self.resnets = nn.ModuleList(resnets)
182
+ self.temp_convs = nn.ModuleList(temp_convs)
183
+
184
+ if temporal_downsample:
185
+ self.temp_convs_down = AllegroTemporalConvLayer(
186
+ out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, down_sample=True, stride=3
187
+ )
188
+ self.add_temp_downsample = temporal_downsample
189
+
190
+ if spatial_downsample:
191
+ self.downsamplers = nn.ModuleList(
192
+ [
193
+ Downsample2D(
194
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
195
+ )
196
+ ]
197
+ )
198
+ else:
199
+ self.downsamplers = None
200
+
201
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
202
+ batch_size = hidden_states.shape[0]
203
+
204
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
205
+
206
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
207
+ hidden_states = resnet(hidden_states, temb=None)
208
+ hidden_states = temp_conv(hidden_states, batch_size=batch_size)
209
+
210
+ if self.add_temp_downsample:
211
+ hidden_states = self.temp_convs_down(hidden_states, batch_size=batch_size)
212
+
213
+ if self.downsamplers is not None:
214
+ for downsampler in self.downsamplers:
215
+ hidden_states = downsampler(hidden_states)
216
+
217
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
218
+ return hidden_states
219
+
220
+
221
+ class AllegroUpBlock3D(nn.Module):
222
+ def __init__(
223
+ self,
224
+ in_channels: int,
225
+ out_channels: int,
226
+ dropout: float = 0.0,
227
+ num_layers: int = 1,
228
+ resnet_eps: float = 1e-6,
229
+ resnet_time_scale_shift: str = "default", # default, spatial
230
+ resnet_act_fn: str = "swish",
231
+ resnet_groups: int = 32,
232
+ resnet_pre_norm: bool = True,
233
+ output_scale_factor: float = 1.0,
234
+ spatial_upsample: bool = True,
235
+ temporal_upsample: bool = False,
236
+ temb_channels: Optional[int] = None,
237
+ ):
238
+ super().__init__()
239
+
240
+ resnets = []
241
+ temp_convs = []
242
+
243
+ for i in range(num_layers):
244
+ input_channels = in_channels if i == 0 else out_channels
245
+
246
+ resnets.append(
247
+ ResnetBlock2D(
248
+ in_channels=input_channels,
249
+ out_channels=out_channels,
250
+ temb_channels=temb_channels,
251
+ eps=resnet_eps,
252
+ groups=resnet_groups,
253
+ dropout=dropout,
254
+ time_embedding_norm=resnet_time_scale_shift,
255
+ non_linearity=resnet_act_fn,
256
+ output_scale_factor=output_scale_factor,
257
+ pre_norm=resnet_pre_norm,
258
+ )
259
+ )
260
+ temp_convs.append(
261
+ AllegroTemporalConvLayer(
262
+ out_channels,
263
+ out_channels,
264
+ dropout=0.1,
265
+ norm_num_groups=resnet_groups,
266
+ )
267
+ )
268
+
269
+ self.resnets = nn.ModuleList(resnets)
270
+ self.temp_convs = nn.ModuleList(temp_convs)
271
+
272
+ self.add_temp_upsample = temporal_upsample
273
+ if temporal_upsample:
274
+ self.temp_conv_up = AllegroTemporalConvLayer(
275
+ out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, up_sample=True, stride=3
276
+ )
277
+
278
+ if spatial_upsample:
279
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
280
+ else:
281
+ self.upsamplers = None
282
+
283
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
284
+ batch_size = hidden_states.shape[0]
285
+
286
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
287
+
288
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
289
+ hidden_states = resnet(hidden_states, temb=None)
290
+ hidden_states = temp_conv(hidden_states, batch_size=batch_size)
291
+
292
+ if self.add_temp_upsample:
293
+ hidden_states = self.temp_conv_up(hidden_states, batch_size=batch_size)
294
+
295
+ if self.upsamplers is not None:
296
+ for upsampler in self.upsamplers:
297
+ hidden_states = upsampler(hidden_states)
298
+
299
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
300
+ return hidden_states
301
+
302
+
303
+ class AllegroMidBlock3DConv(nn.Module):
304
+ def __init__(
305
+ self,
306
+ in_channels: int,
307
+ temb_channels: int,
308
+ dropout: float = 0.0,
309
+ num_layers: int = 1,
310
+ resnet_eps: float = 1e-6,
311
+ resnet_time_scale_shift: str = "default", # default, spatial
312
+ resnet_act_fn: str = "swish",
313
+ resnet_groups: int = 32,
314
+ resnet_pre_norm: bool = True,
315
+ add_attention: bool = True,
316
+ attention_head_dim: int = 1,
317
+ output_scale_factor: float = 1.0,
318
+ ):
319
+ super().__init__()
320
+
321
+ # there is always at least one resnet
322
+ resnets = [
323
+ ResnetBlock2D(
324
+ in_channels=in_channels,
325
+ out_channels=in_channels,
326
+ temb_channels=temb_channels,
327
+ eps=resnet_eps,
328
+ groups=resnet_groups,
329
+ dropout=dropout,
330
+ time_embedding_norm=resnet_time_scale_shift,
331
+ non_linearity=resnet_act_fn,
332
+ output_scale_factor=output_scale_factor,
333
+ pre_norm=resnet_pre_norm,
334
+ )
335
+ ]
336
+ temp_convs = [
337
+ AllegroTemporalConvLayer(
338
+ in_channels,
339
+ in_channels,
340
+ dropout=0.1,
341
+ norm_num_groups=resnet_groups,
342
+ )
343
+ ]
344
+ attentions = []
345
+
346
+ if attention_head_dim is None:
347
+ attention_head_dim = in_channels
348
+
349
+ for _ in range(num_layers):
350
+ if add_attention:
351
+ attentions.append(
352
+ Attention(
353
+ in_channels,
354
+ heads=in_channels // attention_head_dim,
355
+ dim_head=attention_head_dim,
356
+ rescale_output_factor=output_scale_factor,
357
+ eps=resnet_eps,
358
+ norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
359
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
360
+ residual_connection=True,
361
+ bias=True,
362
+ upcast_softmax=True,
363
+ _from_deprecated_attn_block=True,
364
+ )
365
+ )
366
+ else:
367
+ attentions.append(None)
368
+
369
+ resnets.append(
370
+ ResnetBlock2D(
371
+ in_channels=in_channels,
372
+ out_channels=in_channels,
373
+ temb_channels=temb_channels,
374
+ eps=resnet_eps,
375
+ groups=resnet_groups,
376
+ dropout=dropout,
377
+ time_embedding_norm=resnet_time_scale_shift,
378
+ non_linearity=resnet_act_fn,
379
+ output_scale_factor=output_scale_factor,
380
+ pre_norm=resnet_pre_norm,
381
+ )
382
+ )
383
+
384
+ temp_convs.append(
385
+ AllegroTemporalConvLayer(
386
+ in_channels,
387
+ in_channels,
388
+ dropout=0.1,
389
+ norm_num_groups=resnet_groups,
390
+ )
391
+ )
392
+
393
+ self.resnets = nn.ModuleList(resnets)
394
+ self.temp_convs = nn.ModuleList(temp_convs)
395
+ self.attentions = nn.ModuleList(attentions)
396
+
397
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
398
+ batch_size = hidden_states.shape[0]
399
+
400
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
401
+ hidden_states = self.resnets[0](hidden_states, temb=None)
402
+
403
+ hidden_states = self.temp_convs[0](hidden_states, batch_size=batch_size)
404
+
405
+ for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]):
406
+ hidden_states = attn(hidden_states)
407
+ hidden_states = resnet(hidden_states, temb=None)
408
+ hidden_states = temp_conv(hidden_states, batch_size=batch_size)
409
+
410
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
411
+ return hidden_states
412
+
413
+
414
+ class AllegroEncoder3D(nn.Module):
415
+ def __init__(
416
+ self,
417
+ in_channels: int = 3,
418
+ out_channels: int = 3,
419
+ down_block_types: Tuple[str, ...] = (
420
+ "AllegroDownBlock3D",
421
+ "AllegroDownBlock3D",
422
+ "AllegroDownBlock3D",
423
+ "AllegroDownBlock3D",
424
+ ),
425
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
426
+ temporal_downsample_blocks: Tuple[bool, ...] = [True, True, False, False],
427
+ layers_per_block: int = 2,
428
+ norm_num_groups: int = 32,
429
+ act_fn: str = "silu",
430
+ double_z: bool = True,
431
+ ):
432
+ super().__init__()
433
+
434
+ self.conv_in = nn.Conv2d(
435
+ in_channels,
436
+ block_out_channels[0],
437
+ kernel_size=3,
438
+ stride=1,
439
+ padding=1,
440
+ )
441
+
442
+ self.temp_conv_in = nn.Conv3d(
443
+ in_channels=block_out_channels[0],
444
+ out_channels=block_out_channels[0],
445
+ kernel_size=(3, 1, 1),
446
+ padding=(1, 0, 0),
447
+ )
448
+
449
+ self.down_blocks = nn.ModuleList([])
450
+
451
+ # down
452
+ output_channel = block_out_channels[0]
453
+ for i, down_block_type in enumerate(down_block_types):
454
+ input_channel = output_channel
455
+ output_channel = block_out_channels[i]
456
+ is_final_block = i == len(block_out_channels) - 1
457
+
458
+ if down_block_type == "AllegroDownBlock3D":
459
+ down_block = AllegroDownBlock3D(
460
+ num_layers=layers_per_block,
461
+ in_channels=input_channel,
462
+ out_channels=output_channel,
463
+ spatial_downsample=not is_final_block,
464
+ temporal_downsample=temporal_downsample_blocks[i],
465
+ resnet_eps=1e-6,
466
+ downsample_padding=0,
467
+ resnet_act_fn=act_fn,
468
+ resnet_groups=norm_num_groups,
469
+ )
470
+ else:
471
+ raise ValueError("Invalid `down_block_type` encountered. Must be `AllegroDownBlock3D`")
472
+
473
+ self.down_blocks.append(down_block)
474
+
475
+ # mid
476
+ self.mid_block = AllegroMidBlock3DConv(
477
+ in_channels=block_out_channels[-1],
478
+ resnet_eps=1e-6,
479
+ resnet_act_fn=act_fn,
480
+ output_scale_factor=1,
481
+ resnet_time_scale_shift="default",
482
+ attention_head_dim=block_out_channels[-1],
483
+ resnet_groups=norm_num_groups,
484
+ temb_channels=None,
485
+ )
486
+
487
+ # out
488
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
489
+ self.conv_act = nn.SiLU()
490
+
491
+ conv_out_channels = 2 * out_channels if double_z else out_channels
492
+
493
+ self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0))
494
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
495
+
496
+ self.gradient_checkpointing = False
497
+
498
+ def forward(self, sample: torch.Tensor) -> torch.Tensor:
499
+ batch_size = sample.shape[0]
500
+
501
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
502
+ sample = self.conv_in(sample)
503
+
504
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
505
+ residual = sample
506
+ sample = self.temp_conv_in(sample)
507
+ sample = sample + residual
508
+
509
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
510
+
511
+ def create_custom_forward(module):
512
+ def custom_forward(*inputs):
513
+ return module(*inputs)
514
+
515
+ return custom_forward
516
+
517
+ # Down blocks
518
+ for down_block in self.down_blocks:
519
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
520
+
521
+ # Mid block
522
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
523
+ else:
524
+ # Down blocks
525
+ for down_block in self.down_blocks:
526
+ sample = down_block(sample)
527
+
528
+ # Mid block
529
+ sample = self.mid_block(sample)
530
+
531
+ # Post process
532
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
533
+ sample = self.conv_norm_out(sample)
534
+ sample = self.conv_act(sample)
535
+
536
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
537
+ residual = sample
538
+ sample = self.temp_conv_out(sample)
539
+ sample = sample + residual
540
+
541
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
542
+ sample = self.conv_out(sample)
543
+
544
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
545
+ return sample
546
+
547
+
548
+ class AllegroDecoder3D(nn.Module):
549
+ def __init__(
550
+ self,
551
+ in_channels: int = 4,
552
+ out_channels: int = 3,
553
+ up_block_types: Tuple[str, ...] = (
554
+ "AllegroUpBlock3D",
555
+ "AllegroUpBlock3D",
556
+ "AllegroUpBlock3D",
557
+ "AllegroUpBlock3D",
558
+ ),
559
+ temporal_upsample_blocks: Tuple[bool, ...] = [False, True, True, False],
560
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
561
+ layers_per_block: int = 2,
562
+ norm_num_groups: int = 32,
563
+ act_fn: str = "silu",
564
+ norm_type: str = "group", # group, spatial
565
+ ):
566
+ super().__init__()
567
+
568
+ self.conv_in = nn.Conv2d(
569
+ in_channels,
570
+ block_out_channels[-1],
571
+ kernel_size=3,
572
+ stride=1,
573
+ padding=1,
574
+ )
575
+
576
+ self.temp_conv_in = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0))
577
+
578
+ self.mid_block = None
579
+ self.up_blocks = nn.ModuleList([])
580
+
581
+ temb_channels = in_channels if norm_type == "spatial" else None
582
+
583
+ # mid
584
+ self.mid_block = AllegroMidBlock3DConv(
585
+ in_channels=block_out_channels[-1],
586
+ resnet_eps=1e-6,
587
+ resnet_act_fn=act_fn,
588
+ output_scale_factor=1,
589
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
590
+ attention_head_dim=block_out_channels[-1],
591
+ resnet_groups=norm_num_groups,
592
+ temb_channels=temb_channels,
593
+ )
594
+
595
+ # up
596
+ reversed_block_out_channels = list(reversed(block_out_channels))
597
+ output_channel = reversed_block_out_channels[0]
598
+ for i, up_block_type in enumerate(up_block_types):
599
+ prev_output_channel = output_channel
600
+ output_channel = reversed_block_out_channels[i]
601
+
602
+ is_final_block = i == len(block_out_channels) - 1
603
+
604
+ if up_block_type == "AllegroUpBlock3D":
605
+ up_block = AllegroUpBlock3D(
606
+ num_layers=layers_per_block + 1,
607
+ in_channels=prev_output_channel,
608
+ out_channels=output_channel,
609
+ spatial_upsample=not is_final_block,
610
+ temporal_upsample=temporal_upsample_blocks[i],
611
+ resnet_eps=1e-6,
612
+ resnet_act_fn=act_fn,
613
+ resnet_groups=norm_num_groups,
614
+ temb_channels=temb_channels,
615
+ resnet_time_scale_shift=norm_type,
616
+ )
617
+ else:
618
+ raise ValueError("Invalid `UP_block_type` encountered. Must be `AllegroUpBlock3D`")
619
+
620
+ self.up_blocks.append(up_block)
621
+ prev_output_channel = output_channel
622
+
623
+ # out
624
+ if norm_type == "spatial":
625
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
626
+ else:
627
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
628
+
629
+ self.conv_act = nn.SiLU()
630
+
631
+ self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3, 1, 1), padding=(1, 0, 0))
632
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
633
+
634
+ self.gradient_checkpointing = False
635
+
636
+ def forward(self, sample: torch.Tensor) -> torch.Tensor:
637
+ batch_size = sample.shape[0]
638
+
639
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
640
+ sample = self.conv_in(sample)
641
+
642
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
643
+ residual = sample
644
+ sample = self.temp_conv_in(sample)
645
+ sample = sample + residual
646
+
647
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
648
+
649
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
650
+
651
+ def create_custom_forward(module):
652
+ def custom_forward(*inputs):
653
+ return module(*inputs)
654
+
655
+ return custom_forward
656
+
657
+ # Mid block
658
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
659
+
660
+ # Up blocks
661
+ for up_block in self.up_blocks:
662
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
663
+
664
+ else:
665
+ # Mid block
666
+ sample = self.mid_block(sample)
667
+ sample = sample.to(upscale_dtype)
668
+
669
+ # Up blocks
670
+ for up_block in self.up_blocks:
671
+ sample = up_block(sample)
672
+
673
+ # Post process
674
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
675
+ sample = self.conv_norm_out(sample)
676
+ sample = self.conv_act(sample)
677
+
678
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
679
+ residual = sample
680
+ sample = self.temp_conv_out(sample)
681
+ sample = sample + residual
682
+
683
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
684
+ sample = self.conv_out(sample)
685
+
686
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
687
+ return sample
688
+
689
+
690
+ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
691
+ r"""
692
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
693
+ [Allegro](https://github.com/rhymes-ai/Allegro).
694
+
695
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
696
+ for all models (such as downloading or saving).
697
+
698
+ Parameters:
699
+ in_channels (int, defaults to `3`):
700
+ Number of channels in the input image.
701
+ out_channels (int, defaults to `3`):
702
+ Number of channels in the output.
703
+ down_block_types (`Tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`):
704
+ Tuple of strings denoting which types of down blocks to use.
705
+ up_block_types (`Tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`):
706
+ Tuple of strings denoting which types of up blocks to use.
707
+ block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
708
+ Tuple of integers denoting number of output channels in each block.
709
+ temporal_downsample_blocks (`Tuple[bool, ...]`, defaults to `(True, True, False, False)`):
710
+ Tuple of booleans denoting which blocks to enable temporal downsampling in.
711
+ latent_channels (`int`, defaults to `4`):
712
+ Number of channels in latents.
713
+ layers_per_block (`int`, defaults to `2`):
714
+ Number of resnet or attention or temporal convolution layers per down/up block.
715
+ act_fn (`str`, defaults to `"silu"`):
716
+ The activation function to use.
717
+ norm_num_groups (`int`, defaults to `32`):
718
+ Number of groups to use in normalization layers.
719
+ temporal_compression_ratio (`int`, defaults to `4`):
720
+ Ratio by which temporal dimension of samples are compressed.
721
+ sample_size (`int`, defaults to `320`):
722
+ Default latent size.
723
+ scaling_factor (`float`, defaults to `0.13235`):
724
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
725
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
726
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
727
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
728
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
729
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
730
+ force_upcast (`bool`, default to `True`):
731
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
732
+ can be fine-tuned / trained to a lower range without loosing too much precision in which case
733
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
734
+ """
+
+     _supports_gradient_checkpointing = True
+
+     @register_to_config
+     def __init__(
+         self,
+         in_channels: int = 3,
+         out_channels: int = 3,
+         down_block_types: Tuple[str, ...] = (
+             "AllegroDownBlock3D",
+             "AllegroDownBlock3D",
+             "AllegroDownBlock3D",
+             "AllegroDownBlock3D",
+         ),
+         up_block_types: Tuple[str, ...] = (
+             "AllegroUpBlock3D",
+             "AllegroUpBlock3D",
+             "AllegroUpBlock3D",
+             "AllegroUpBlock3D",
+         ),
+         block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+         temporal_downsample_blocks: Tuple[bool, ...] = (True, True, False, False),
+         temporal_upsample_blocks: Tuple[bool, ...] = (False, True, True, False),
+         latent_channels: int = 4,
+         layers_per_block: int = 2,
+         act_fn: str = "silu",
+         norm_num_groups: int = 32,
+         temporal_compression_ratio: float = 4,
+         sample_size: int = 320,
+         scaling_factor: float = 0.13,
+         force_upcast: bool = True,
+     ) -> None:
+         super().__init__()
+
+         self.encoder = AllegroEncoder3D(
+             in_channels=in_channels,
+             out_channels=latent_channels,
+             down_block_types=down_block_types,
+             temporal_downsample_blocks=temporal_downsample_blocks,
+             block_out_channels=block_out_channels,
+             layers_per_block=layers_per_block,
+             act_fn=act_fn,
+             norm_num_groups=norm_num_groups,
+             double_z=True,
+         )
+         self.decoder = AllegroDecoder3D(
+             in_channels=latent_channels,
+             out_channels=out_channels,
+             up_block_types=up_block_types,
+             temporal_upsample_blocks=temporal_upsample_blocks,
+             block_out_channels=block_out_channels,
+             layers_per_block=layers_per_block,
+             norm_num_groups=norm_num_groups,
+             act_fn=act_fn,
+         )
+         self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+         self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
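+         # The encoder output has 2 * latent_channels channels (posterior mean and log-variance); `quant_conv` is a
+         # 1x1 convolution over those parameters, and `post_quant_conv` maps sampled latents back before the decoder.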
+
+         # TODO(aryan): For the 1.0.0 refactor, `temporal_compression_ratio` can be inferred directly and we don't need
+         # to use a specific parameter here or in other VAEs.
+
+         self.use_slicing = False
+         self.use_tiling = False
+
+         self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
+         self.tile_overlap_t = 8
+         self.tile_overlap_h = 120
+         self.tile_overlap_w = 80
+         sample_frames = 24
+
+         self.kernel = (sample_frames, sample_size, sample_size)
+         self.stride = (
+             sample_frames - self.tile_overlap_t,
+             sample_size - self.tile_overlap_h,
+             sample_size - self.tile_overlap_w,
+         )
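+         # `kernel` is the (frames, height, width) size of each tile cut from the input video and `stride` is the
+         # step between neighbouring tiles, so `kernel - stride` gives the overlap that `tiled_encode` and
+         # `tiled_decode` blend across via `_prepare_for_blend`.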
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
+             module.gradient_checkpointing = value
+
+     def enable_tiling(self) -> None:
+         r"""
+         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+         compute decoding and encoding in several steps. This is useful for saving a large amount of memory and
+         allows processing larger videos.
+         """
+         self.use_tiling = True
+
+     def disable_tiling(self) -> None:
+         r"""
+         Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+         decoding in one step.
+         """
+         self.use_tiling = False
+
+     def enable_slicing(self) -> None:
+         r"""
+         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+         compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+         """
+         self.use_slicing = True
+
+     def disable_slicing(self) -> None:
+         r"""
+         Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+         decoding in one step.
+         """
+         self.use_slicing = False
+
+     def _encode(self, x: torch.Tensor) -> torch.Tensor:
+         # TODO(aryan)
+         # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+         if self.use_tiling:
+             return self.tiled_encode(x)
+
+         raise NotImplementedError("Encoding without tiling has not been implemented yet.")
+
+     @apply_forward_hook
+     def encode(
+         self, x: torch.Tensor, return_dict: bool = True
+     ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+         r"""
+         Encode a batch of videos into latents.
+
+         Args:
+             x (`torch.Tensor`):
+                 Input batch of videos.
+             return_dict (`bool`, defaults to `True`):
+                 Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+         Returns:
+                 The latent representations of the encoded videos. If `return_dict` is True, a
+                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+         """
+         if self.use_slicing and x.shape[0] > 1:
+             encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+             h = torch.cat(encoded_slices)
+         else:
+             h = self._encode(x)
+
+         posterior = DiagonalGaussianDistribution(h)
+
+         if not return_dict:
+             return (posterior,)
+         return AutoencoderKLOutput(latent_dist=posterior)
+
+     def _decode(self, z: torch.Tensor) -> torch.Tensor:
+         # TODO(aryan): refactor tiling implementation
+         # if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
+         if self.use_tiling:
+             return self.tiled_decode(z)
+
+         raise NotImplementedError("Decoding without tiling has not been implemented yet.")
+
+     @apply_forward_hook
+     def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+         """
+         Decode a batch of videos.
+
+         Args:
+             z (`torch.Tensor`):
+                 Input batch of latent vectors.
+             return_dict (`bool`, defaults to `True`):
+                 Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+         Returns:
+             [`~models.vae.DecoderOutput`] or `tuple`:
+                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                 returned.
+         """
+         if self.use_slicing and z.shape[0] > 1:
+             decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
+             decoded = torch.cat(decoded_slices)
+         else:
+             decoded = self._decode(z)
+
+         if not return_dict:
+             return (decoded,)
+         return DecoderOutput(sample=decoded)
+
+     def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+         local_batch_size = 1
+         rs = self.spatial_compression_ratio
+         rt = self.config.temporal_compression_ratio
+
+         batch_size, num_channels, num_frames, height, width = x.shape
+
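+         # Number of tiles along the temporal, height and width dimensions of the input video.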
+         output_num_frames = math.floor((num_frames - self.kernel[0]) / self.stride[0]) + 1
+         output_height = math.floor((height - self.kernel[1]) / self.stride[1]) + 1
+         output_width = math.floor((width - self.kernel[2]) / self.stride[2]) + 1
+
+         count = 0
+         output_latent = x.new_zeros(
+             (
+                 output_num_frames * output_height * output_width,
+                 2 * self.config.latent_channels,
+                 self.kernel[0] // rt,
+                 self.kernel[1] // rs,
+                 self.kernel[2] // rs,
+             )
+         )
+         vae_batch_input = x.new_zeros((local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2]))
+
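+         # First pass: encode every overlapping kernel-sized tile of the input video independently, in batches of
+         # `local_batch_size`, and collect the per-tile posterior parameters in `output_latent`.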
+         for i in range(output_num_frames):
+             for j in range(output_height):
+                 for k in range(output_width):
+                     n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0]
+                     h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1]
+                     w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2]
+
+                     video_cube = x[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
+                     vae_batch_input[count % local_batch_size] = video_cube
+
+                     if (
+                         count % local_batch_size == local_batch_size - 1
+                         or count == output_num_frames * output_height * output_width - 1
+                     ):
+                         latent = self.encoder(vae_batch_input)
+
+                         if (
+                             count == output_num_frames * output_height * output_width - 1
+                             and count % local_batch_size != local_batch_size - 1
+                         ):
+                             output_latent[count - count % local_batch_size :] = latent[: count % local_batch_size + 1]
+                         else:
+                             output_latent[count - local_batch_size + 1 : count + 1] = latent
+
+                         vae_batch_input = x.new_zeros(
+                             (local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2])
+                         )
+
+                     count += 1
+
+         latent = x.new_zeros(
+             (batch_size, 2 * self.config.latent_channels, num_frames // rt, height // rs, width // rs)
+         )
+         output_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs
+         output_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs
+         output_overlap = (
+             output_kernel[0] - output_stride[0],
+             output_kernel[1] - output_stride[1],
+             output_kernel[2] - output_stride[2],
+         )
+
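+         # Second pass: scatter the per-tile latents onto a full-size latent grid, feathering the overlapping
+         # borders with `_prepare_for_blend` so neighbouring tiles cross-fade instead of adding up abruptly.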
+         for i in range(output_num_frames):
+             n_start, n_end = i * output_stride[0], i * output_stride[0] + output_kernel[0]
+             for j in range(output_height):
+                 h_start, h_end = j * output_stride[1], j * output_stride[1] + output_kernel[1]
+                 for k in range(output_width):
+                     w_start, w_end = k * output_stride[2], k * output_stride[2] + output_kernel[2]
+                     latent_mean = _prepare_for_blend(
+                         (i, output_num_frames, output_overlap[0]),
+                         (j, output_height, output_overlap[1]),
+                         (k, output_width, output_overlap[2]),
+                         output_latent[i * output_height * output_width + j * output_width + k].unsqueeze(0),
+                     )
+                     latent[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean
+
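+         # Fold the temporal dimension into the batch so the 2D `quant_conv` runs on each frame, then restore the
+         # (B, C, T, H, W) layout.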
+         latent = latent.permute(0, 2, 1, 3, 4).flatten(0, 1)
+         latent = self.quant_conv(latent)
+         latent = latent.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+         return latent
+
+     def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
+         local_batch_size = 1
+         rs = self.spatial_compression_ratio
+         rt = self.config.temporal_compression_ratio
+
+         latent_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs
+         latent_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs
+
+         batch_size, num_channels, num_frames, height, width = z.shape
+
+         # post quant conv (a mapping)
+         z = z.permute(0, 2, 1, 3, 4).flatten(0, 1)
+         z = self.post_quant_conv(z)
+         z = z.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+
+         output_num_frames = math.floor((num_frames - latent_kernel[0]) / latent_stride[0]) + 1
+         output_height = math.floor((height - latent_kernel[1]) / latent_stride[1]) + 1
+         output_width = math.floor((width - latent_kernel[2]) / latent_stride[2]) + 1
+
+         count = 0
+         decoded_videos = z.new_zeros(
+             (
+                 output_num_frames * output_height * output_width,
+                 self.config.out_channels,
+                 self.kernel[0],
+                 self.kernel[1],
+                 self.kernel[2],
+             )
+         )
+         vae_batch_input = z.new_zeros(
+             (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2])
+         )
+
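+         # First pass: decode every latent tile independently, in batches of `local_batch_size`, and store the
+         # resulting pixel-space tiles in `decoded_videos`.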
+         for i in range(output_num_frames):
+             for j in range(output_height):
+                 for k in range(output_width):
+                     n_start, n_end = i * latent_stride[0], i * latent_stride[0] + latent_kernel[0]
+                     h_start, h_end = j * latent_stride[1], j * latent_stride[1] + latent_kernel[1]
+                     w_start, w_end = k * latent_stride[2], k * latent_stride[2] + latent_kernel[2]
+
+                     current_latent = z[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
+                     vae_batch_input[count % local_batch_size] = current_latent
+
+                     if (
+                         count % local_batch_size == local_batch_size - 1
+                         or count == output_num_frames * output_height * output_width - 1
+                     ):
+                         current_video = self.decoder(vae_batch_input)
+
+                         if (
+                             count == output_num_frames * output_height * output_width - 1
+                             and count % local_batch_size != local_batch_size - 1
+                         ):
+                             decoded_videos[count - count % local_batch_size :] = current_video[
+                                 : count % local_batch_size + 1
+                             ]
+                         else:
+                             decoded_videos[count - local_batch_size + 1 : count + 1] = current_video
+
+                         vae_batch_input = z.new_zeros(
+                             (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2])
+                         )
+
+                     count += 1
+
+         video = z.new_zeros((batch_size, self.config.out_channels, num_frames * rt, height * rs, width * rs))
+         video_overlap = (
+             self.kernel[0] - self.stride[0],
+             self.kernel[1] - self.stride[1],
+             self.kernel[2] - self.stride[2],
+         )
+
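+         # Second pass: blend the decoded tiles into the full-resolution video, cross-fading linearly over the
+         # overlapping regions with `_prepare_for_blend`.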
+         for i in range(output_num_frames):
+             n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0]
+             for j in range(output_height):
+                 h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1]
+                 for k in range(output_width):
+                     w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2]
+                     out_video_blend = _prepare_for_blend(
+                         (i, output_num_frames, video_overlap[0]),
+                         (j, output_height, video_overlap[1]),
+                         (k, output_width, video_overlap[2]),
+                         decoded_videos[i * output_height * output_width + j * output_width + k].unsqueeze(0),
+                     )
+                     video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend
+
+         video = video.permute(0, 2, 1, 3, 4).contiguous()
+         return video
+
+     def forward(
+         self,
+         sample: torch.Tensor,
+         sample_posterior: bool = False,
+         return_dict: bool = True,
+         generator: Optional[torch.Generator] = None,
+     ) -> Union[DecoderOutput, torch.Tensor]:
+         r"""
+         Args:
+             sample (`torch.Tensor`): Input sample.
+             sample_posterior (`bool`, *optional*, defaults to `False`):
+                 Whether to sample from the posterior.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+             generator (`torch.Generator`, *optional*):
+                 PyTorch random number generator.
+         """
+         x = sample
+         posterior = self.encode(x).latent_dist
+         if sample_posterior:
+             z = posterior.sample(generator=generator)
+         else:
+             z = posterior.mode()
+         dec = self.decode(z).sample
+
+         if not return_dict:
+             return (dec,)
+
+         return DecoderOutput(sample=dec)
+
+
+ def _prepare_for_blend(n_param, h_param, w_param, x):
+     # TODO(aryan): refactor
+     n, n_max, overlap_n = n_param
+     h, h_max, overlap_h = h_param
+     w, w_max, overlap_w = w_param
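+     # Multiply the borders of the tile by linear ramps so that, when overlapping tiles are summed by the caller,
+     # each overlap region cross-fades from one tile to the next (head ramps 0 -> 1, tail ramps 1 -> 0).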
+     if overlap_n > 0:
+         if n > 0:  # the head overlap part decays from 0 to 1
+             x[:, :, 0:overlap_n, :, :] = x[:, :, 0:overlap_n, :, :] * (
+                 torch.arange(0, overlap_n).float().to(x.device) / overlap_n
+             ).reshape(overlap_n, 1, 1)
+         if n < n_max - 1:  # the tail overlap part decays from 1 to 0
+             x[:, :, -overlap_n:, :, :] = x[:, :, -overlap_n:, :, :] * (
+                 1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n
+             ).reshape(overlap_n, 1, 1)
+     if h > 0:
+         x[:, :, :, 0:overlap_h, :] = x[:, :, :, 0:overlap_h, :] * (
+             torch.arange(0, overlap_h).float().to(x.device) / overlap_h
+         ).reshape(overlap_h, 1)
+     if h < h_max - 1:
+         x[:, :, :, -overlap_h:, :] = x[:, :, :, -overlap_h:, :] * (
+             1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h
+         ).reshape(overlap_h, 1)
+     if w > 0:
+         x[:, :, :, :, 0:overlap_w] = x[:, :, :, :, 0:overlap_w] * (
+             torch.arange(0, overlap_w).float().to(x.device) / overlap_w
+         )
+     if w < w_max - 1:
+         x[:, :, :, :, -overlap_w:] = x[:, :, :, :, -overlap_w:] * (
+             1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w
+         )
+     return x
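+
+
+ # A minimal usage sketch (assumptions: the `rhymes-ai/Allegro` checkpoint ships a compatible `vae` subfolder, and the
+ # video tensor shape below is only illustrative). Tiling must be enabled, since `_encode`/`_decode` raise
+ # `NotImplementedError` without it:
+ #
+ #     import torch
+ #     from diffusers import AutoencoderKLAllegro
+ #
+ #     vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32)
+ #     vae.enable_tiling()
+ #     video = torch.randn(1, 3, 24, 320, 320)  # (batch, channels, frames, height, width)
+ #     posterior = vae.encode(video).latent_dist
+ #     latents = posterior.sample()
+ #     reconstruction = vae.decode(latents).sample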