diffusers 0.27.1__py3-none-any.whl → 0.32.2__py3-none-any.whl

Files changed (445)
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +41 -40
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.1.dist-info/RECORD +0 -399
  443. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
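Most of the growth in this range comes from new model families (Flux, CogVideoX, Mochi, HunyuanVideo, LTX, Sana, Stable Diffusion 3, PAG) and the new diffusers/quantizers package. As a quick post-upgrade smoke test, a minimal sketch with top-level names inferred from the file list above (the exact top-level exports are an assumption, not verified against the wheel):

import diffusers

print(diffusers.__version__)  # expected: 0.32.2 after the upgrade

# Names inferred from the new files listed above; whether they resolve also
# depends on the optional torch/transformers extras being installed.
for name in ("AutoencoderKLMochi", "FluxPipeline", "BitsAndBytesConfig"):
    print(name, hasattr(diffusers, name))

The largest single new file, the Mochi video VAE, is shown in full below.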
diffusers/models/autoencoders/autoencoder_kl_mochi.py
@@ -0,0 +1,1166 @@
1
+ # Copyright 2024 The Mochi team and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ from typing import Dict, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...utils import logging
24
+ from ...utils.accelerate_utils import apply_forward_hook
25
+ from ..activations import get_activation
26
+ from ..attention_processor import Attention, MochiVaeAttnProcessor2_0
27
+ from ..modeling_outputs import AutoencoderKLOutput
28
+ from ..modeling_utils import ModelMixin
29
+ from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d
30
+ from .vae import DecoderOutput, DiagonalGaussianDistribution
31
+
32
+
33
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+
36
+ class MochiChunkedGroupNorm3D(nn.Module):
37
+ r"""
38
+ Applies per-frame group normalization for 5D video inputs. It also supports memory-efficient chunked group
39
+ normalization.
40
+
41
+ Args:
42
+ num_channels (int): Number of channels expected in input
43
+ num_groups (int, optional): Number of groups to separate the channels into. Default: 32
44
+ affine (bool, optional): If True, this module has learnable affine parameters. Default: True
45
+ chunk_size (int, optional): Size of each chunk for processing. Default: 8
46
+
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ num_channels: int,
52
+ num_groups: int = 32,
53
+ affine: bool = True,
54
+ chunk_size: int = 8,
55
+ ):
56
+ super().__init__()
57
+ self.norm_layer = nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, affine=affine)
58
+ self.chunk_size = chunk_size
59
+
60
+ def forward(self, x: torch.Tensor = None) -> torch.Tensor:
61
+ batch_size = x.size(0)
62
+
63
+ x = x.permute(0, 2, 1, 3, 4).flatten(0, 1)
64
+ output = torch.cat([self.norm_layer(chunk) for chunk in x.split(self.chunk_size, dim=0)], dim=0)
65
+ output = output.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
66
+
67
+ return output
68
+
69
+
70
+ class MochiResnetBlock3D(nn.Module):
71
+ r"""
72
+ A 3D ResNet block used in the Mochi model.
73
+
74
+ Args:
75
+ in_channels (`int`):
76
+ Number of input channels.
77
+ out_channels (`int`, *optional*):
78
+ Number of output channels. If None, defaults to `in_channels`.
79
+ non_linearity (`str`, defaults to `"swish"`):
80
+ Activation function to use.
81
+ """
82
+
83
+ def __init__(
84
+ self,
85
+ in_channels: int,
86
+ out_channels: Optional[int] = None,
87
+ act_fn: str = "swish",
88
+ ):
89
+ super().__init__()
90
+
91
+ out_channels = out_channels or in_channels
92
+
93
+ self.in_channels = in_channels
94
+ self.out_channels = out_channels
95
+ self.nonlinearity = get_activation(act_fn)
96
+
97
+ self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels)
98
+ self.conv1 = CogVideoXCausalConv3d(
99
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, pad_mode="replicate"
100
+ )
101
+ self.norm2 = MochiChunkedGroupNorm3D(num_channels=out_channels)
102
+ self.conv2 = CogVideoXCausalConv3d(
103
+ in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, pad_mode="replicate"
104
+ )
105
+
106
+ def forward(
107
+ self,
108
+ inputs: torch.Tensor,
109
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
110
+ ) -> torch.Tensor:
111
+ new_conv_cache = {}
112
+ conv_cache = conv_cache or {}
113
+
114
+ hidden_states = inputs
115
+
116
+ hidden_states = self.norm1(hidden_states)
117
+ hidden_states = self.nonlinearity(hidden_states)
118
+ hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
119
+
120
+ hidden_states = self.norm2(hidden_states)
121
+ hidden_states = self.nonlinearity(hidden_states)
122
+ hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
123
+
124
+ hidden_states = hidden_states + inputs
125
+ return hidden_states, new_conv_cache
126
+
127
+
128
+ class MochiDownBlock3D(nn.Module):
129
+ r"""
130
+ An downsampling block used in the Mochi model.
131
+
132
+ Args:
133
+ in_channels (`int`):
134
+ Number of input channels.
135
+ out_channels (`int`, *optional*):
136
+ Number of output channels. If None, defaults to `in_channels`.
137
+ num_layers (`int`, defaults to `1`):
138
+ Number of resnet blocks in the block.
139
+ temporal_expansion (`int`, defaults to `2`):
140
+ Temporal expansion factor.
141
+ spatial_expansion (`int`, defaults to `2`):
142
+ Spatial expansion factor.
143
+ """
144
+
145
+ def __init__(
146
+ self,
147
+ in_channels: int,
148
+ out_channels: int,
149
+ num_layers: int = 1,
150
+ temporal_expansion: int = 2,
151
+ spatial_expansion: int = 2,
152
+ add_attention: bool = True,
153
+ ):
154
+ super().__init__()
155
+ self.temporal_expansion = temporal_expansion
156
+ self.spatial_expansion = spatial_expansion
157
+
158
+ self.conv_in = CogVideoXCausalConv3d(
159
+ in_channels=in_channels,
160
+ out_channels=out_channels,
161
+ kernel_size=(temporal_expansion, spatial_expansion, spatial_expansion),
162
+ stride=(temporal_expansion, spatial_expansion, spatial_expansion),
163
+ pad_mode="replicate",
164
+ )
165
+
166
+ resnets = []
167
+ norms = []
168
+ attentions = []
169
+ for _ in range(num_layers):
170
+ resnets.append(MochiResnetBlock3D(in_channels=out_channels))
171
+ if add_attention:
172
+ norms.append(MochiChunkedGroupNorm3D(num_channels=out_channels))
173
+ attentions.append(
174
+ Attention(
175
+ query_dim=out_channels,
176
+ heads=out_channels // 32,
177
+ dim_head=32,
178
+ qk_norm="l2",
179
+ is_causal=True,
180
+ processor=MochiVaeAttnProcessor2_0(),
181
+ )
182
+ )
183
+ else:
184
+ norms.append(None)
185
+ attentions.append(None)
186
+
187
+ self.resnets = nn.ModuleList(resnets)
188
+ self.norms = nn.ModuleList(norms)
189
+ self.attentions = nn.ModuleList(attentions)
190
+
191
+ self.gradient_checkpointing = False
192
+
193
+ def forward(
194
+ self,
195
+ hidden_states: torch.Tensor,
196
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
197
+ chunk_size: int = 2**15,
198
+ ) -> torch.Tensor:
199
+ r"""Forward method of the `MochiUpBlock3D` class."""
200
+
201
+ new_conv_cache = {}
202
+ conv_cache = conv_cache or {}
203
+
204
+ hidden_states, new_conv_cache["conv_in"] = self.conv_in(hidden_states)
205
+
206
+ for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
207
+ conv_cache_key = f"resnet_{i}"
208
+
209
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
210
+
211
+ def create_custom_forward(module):
212
+ def create_forward(*inputs):
213
+ return module(*inputs)
214
+
215
+ return create_forward
216
+
217
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
218
+ create_custom_forward(resnet),
219
+ hidden_states,
220
+ conv_cache=conv_cache.get(conv_cache_key),
221
+ )
222
+ else:
223
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
224
+ hidden_states, conv_cache=conv_cache.get(conv_cache_key)
225
+ )
226
+
227
+ if attn is not None:
228
+ residual = hidden_states
229
+ hidden_states = norm(hidden_states)
230
+
231
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
232
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).flatten(0, 2).contiguous()
233
+
234
+ # Perform attention in chunks to avoid following error:
235
+ # RuntimeError: CUDA error: invalid configuration argument
236
+ if hidden_states.size(0) <= chunk_size:
237
+ hidden_states = attn(hidden_states)
238
+ else:
239
+ hidden_states_chunks = []
240
+ for i in range(0, hidden_states.size(0), chunk_size):
241
+ hidden_states_chunk = hidden_states[i : i + chunk_size]
242
+ hidden_states_chunk = attn(hidden_states_chunk)
243
+ hidden_states_chunks.append(hidden_states_chunk)
244
+ hidden_states = torch.cat(hidden_states_chunks)
245
+
246
+ hidden_states = hidden_states.unflatten(0, (batch_size, height, width)).permute(0, 4, 3, 1, 2)
247
+
248
+ hidden_states = residual + hidden_states
249
+
250
+ return hidden_states, new_conv_cache
251
+
252
+
253
+ class MochiMidBlock3D(nn.Module):
254
+ r"""
255
+ A middle block used in the Mochi model.
256
+
257
+ Args:
258
+ in_channels (`int`):
259
+ Number of input channels.
260
+ num_layers (`int`, defaults to `3`):
261
+ Number of resnet blocks in the block.
262
+ """
263
+
264
+ def __init__(
265
+ self,
266
+ in_channels: int, # 768
267
+ num_layers: int = 3,
268
+ add_attention: bool = True,
269
+ ):
270
+ super().__init__()
271
+
272
+ resnets = []
273
+ norms = []
274
+ attentions = []
275
+
276
+ for _ in range(num_layers):
277
+ resnets.append(MochiResnetBlock3D(in_channels=in_channels))
278
+
279
+ if add_attention:
280
+ norms.append(MochiChunkedGroupNorm3D(num_channels=in_channels))
281
+ attentions.append(
282
+ Attention(
283
+ query_dim=in_channels,
284
+ heads=in_channels // 32,
285
+ dim_head=32,
286
+ qk_norm="l2",
287
+ is_causal=True,
288
+ processor=MochiVaeAttnProcessor2_0(),
289
+ )
290
+ )
291
+ else:
292
+ norms.append(None)
293
+ attentions.append(None)
294
+
295
+ self.resnets = nn.ModuleList(resnets)
296
+ self.norms = nn.ModuleList(norms)
297
+ self.attentions = nn.ModuleList(attentions)
298
+
299
+ self.gradient_checkpointing = False
300
+
301
+ def forward(
302
+ self,
303
+ hidden_states: torch.Tensor,
304
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
305
+ ) -> torch.Tensor:
306
+ r"""Forward method of the `MochiMidBlock3D` class."""
307
+
308
+ new_conv_cache = {}
309
+ conv_cache = conv_cache or {}
310
+
311
+ for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
312
+ conv_cache_key = f"resnet_{i}"
313
+
314
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
315
+
316
+ def create_custom_forward(module):
317
+ def create_forward(*inputs):
318
+ return module(*inputs)
319
+
320
+ return create_forward
321
+
322
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
323
+ create_custom_forward(resnet), hidden_states, conv_cache=conv_cache.get(conv_cache_key)
324
+ )
325
+ else:
326
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
327
+ hidden_states, conv_cache=conv_cache.get(conv_cache_key)
328
+ )
329
+
330
+ if attn is not None:
331
+ residual = hidden_states
332
+ hidden_states = norm(hidden_states)
333
+
334
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
335
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).flatten(0, 2).contiguous()
336
+ hidden_states = attn(hidden_states)
337
+ hidden_states = hidden_states.unflatten(0, (batch_size, height, width)).permute(0, 4, 3, 1, 2)
338
+
339
+ hidden_states = residual + hidden_states
340
+
341
+ return hidden_states, new_conv_cache
342
+
343
+
344
+ class MochiUpBlock3D(nn.Module):
345
+ r"""
346
+ An upsampling block used in the Mochi model.
347
+
348
+ Args:
349
+ in_channels (`int`):
350
+ Number of input channels.
351
+ out_channels (`int`, *optional*):
352
+ Number of output channels. If None, defaults to `in_channels`.
353
+ num_layers (`int`, defaults to `1`):
354
+ Number of resnet blocks in the block.
355
+ temporal_expansion (`int`, defaults to `2`):
356
+ Temporal expansion factor.
357
+ spatial_expansion (`int`, defaults to `2`):
358
+ Spatial expansion factor.
359
+ """
360
+
361
+ def __init__(
362
+ self,
363
+ in_channels: int,
364
+ out_channels: int,
365
+ num_layers: int = 1,
366
+ temporal_expansion: int = 2,
367
+ spatial_expansion: int = 2,
368
+ ):
369
+ super().__init__()
370
+ self.temporal_expansion = temporal_expansion
371
+ self.spatial_expansion = spatial_expansion
372
+
373
+ resnets = []
374
+ for _ in range(num_layers):
375
+ resnets.append(MochiResnetBlock3D(in_channels=in_channels))
376
+ self.resnets = nn.ModuleList(resnets)
377
+
378
+ self.proj = nn.Linear(in_channels, out_channels * temporal_expansion * spatial_expansion**2)
379
+
380
+ self.gradient_checkpointing = False
381
+
382
+ def forward(
383
+ self,
384
+ hidden_states: torch.Tensor,
385
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
386
+ ) -> torch.Tensor:
387
+ r"""Forward method of the `MochiUpBlock3D` class."""
388
+
389
+ new_conv_cache = {}
390
+ conv_cache = conv_cache or {}
391
+
392
+ for i, resnet in enumerate(self.resnets):
393
+ conv_cache_key = f"resnet_{i}"
394
+
395
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
396
+
397
+ def create_custom_forward(module):
398
+ def create_forward(*inputs):
399
+ return module(*inputs)
400
+
401
+ return create_forward
402
+
403
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
404
+ create_custom_forward(resnet),
405
+ hidden_states,
406
+ conv_cache=conv_cache.get(conv_cache_key),
407
+ )
408
+ else:
409
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
410
+ hidden_states, conv_cache=conv_cache.get(conv_cache_key)
411
+ )
412
+
413
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1)
414
+ hidden_states = self.proj(hidden_states)
415
+ hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
416
+
417
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
418
+ st = self.temporal_expansion
419
+ sh = self.spatial_expansion
420
+ sw = self.spatial_expansion
421
+
422
+ # Reshape and unpatchify
423
+ hidden_states = hidden_states.view(batch_size, -1, st, sh, sw, num_frames, height, width)
424
+ hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
425
+ hidden_states = hidden_states.view(batch_size, -1, num_frames * st, height * sh, width * sw)
426
+
427
+ return hidden_states, new_conv_cache
428
+
429
+
430
+ class FourierFeatures(nn.Module):
431
+ def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
432
+ super().__init__()
433
+
434
+ self.start = start
435
+ self.stop = stop
436
+ self.step = step
437
+
438
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
439
+ r"""Forward method of the `FourierFeatures` class."""
440
+ original_dtype = inputs.dtype
441
+ inputs = inputs.to(torch.float32)
442
+ num_channels = inputs.shape[1]
443
+ num_freqs = (self.stop - self.start) // self.step
444
+
445
+ freqs = torch.arange(self.start, self.stop, self.step, dtype=inputs.dtype, device=inputs.device)
446
+ w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs]
447
+ w = w.repeat(num_channels)[None, :, None, None, None] # [1, num_channels * num_freqs, 1, 1, 1]
448
+
449
+ # Interleaved repeat of input channels to match w
450
+ h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
451
+ # Scale channels by frequency.
452
+ h = w * h
453
+
454
+ return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1).to(original_dtype)
455
+
456
+
457
+ class MochiEncoder3D(nn.Module):
458
+ r"""
459
+ The `MochiEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
460
+ representation.
461
+
462
+ Args:
463
+ in_channels (`int`, *optional*):
464
+ The number of input channels.
465
+ out_channels (`int`, *optional*):
466
+ The number of output channels.
467
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
468
+ The number of output channels for each block.
469
+ layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
470
+ The number of resnet blocks for each block.
471
+ temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
472
+ The temporal expansion factor for each of the up blocks.
473
+ spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
474
+ The spatial expansion factor for each of the up blocks.
475
+ non_linearity (`str`, *optional*, defaults to `"swish"`):
476
+ The non-linearity to use in the decoder.
477
+ """
478
+
479
+ def __init__(
480
+ self,
481
+ in_channels: int,
482
+ out_channels: int,
483
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
484
+ layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
485
+ temporal_expansions: Tuple[int, ...] = (1, 2, 3),
486
+ spatial_expansions: Tuple[int, ...] = (2, 2, 2),
487
+ add_attention_block: Tuple[bool, ...] = (False, True, True, True, True),
488
+ act_fn: str = "swish",
489
+ ):
490
+ super().__init__()
491
+
492
+ self.nonlinearity = get_activation(act_fn)
493
+
494
+ self.fourier_features = FourierFeatures()
495
+ self.proj_in = nn.Linear(in_channels, block_out_channels[0])
496
+ self.block_in = MochiMidBlock3D(
497
+ in_channels=block_out_channels[0], num_layers=layers_per_block[0], add_attention=add_attention_block[0]
498
+ )
499
+
500
+ down_blocks = []
501
+ for i in range(len(block_out_channels) - 1):
502
+ down_block = MochiDownBlock3D(
503
+ in_channels=block_out_channels[i],
504
+ out_channels=block_out_channels[i + 1],
505
+ num_layers=layers_per_block[i + 1],
506
+ temporal_expansion=temporal_expansions[i],
507
+ spatial_expansion=spatial_expansions[i],
508
+ add_attention=add_attention_block[i + 1],
509
+ )
510
+ down_blocks.append(down_block)
511
+ self.down_blocks = nn.ModuleList(down_blocks)
512
+
513
+ self.block_out = MochiMidBlock3D(
514
+ in_channels=block_out_channels[-1], num_layers=layers_per_block[-1], add_attention=add_attention_block[-1]
515
+ )
516
+ self.norm_out = MochiChunkedGroupNorm3D(block_out_channels[-1])
517
+ self.proj_out = nn.Linear(block_out_channels[-1], 2 * out_channels, bias=False)
518
+
519
+ def forward(
520
+ self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
521
+ ) -> torch.Tensor:
522
+ r"""Forward method of the `MochiEncoder3D` class."""
523
+
524
+ new_conv_cache = {}
525
+ conv_cache = conv_cache or {}
526
+
527
+ hidden_states = self.fourier_features(hidden_states)
528
+
529
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1)
530
+ hidden_states = self.proj_in(hidden_states)
531
+ hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
532
+
533
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
534
+
535
+ def create_custom_forward(module):
536
+ def create_forward(*inputs):
537
+ return module(*inputs)
538
+
539
+ return create_forward
540
+
541
+ hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint(
542
+ create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in")
543
+ )
544
+
545
+ for i, down_block in enumerate(self.down_blocks):
546
+ conv_cache_key = f"down_block_{i}"
547
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
548
+ create_custom_forward(down_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key)
549
+ )
550
+ else:
551
+ hidden_states, new_conv_cache["block_in"] = self.block_in(
552
+ hidden_states, conv_cache=conv_cache.get("block_in")
553
+ )
554
+
555
+ for i, down_block in enumerate(self.down_blocks):
556
+ conv_cache_key = f"down_block_{i}"
557
+ hidden_states, new_conv_cache[conv_cache_key] = down_block(
558
+ hidden_states, conv_cache=conv_cache.get(conv_cache_key)
559
+ )
560
+
561
+ hidden_states, new_conv_cache["block_out"] = self.block_out(
562
+ hidden_states, conv_cache=conv_cache.get("block_out")
563
+ )
564
+
565
+ hidden_states = self.norm_out(hidden_states)
566
+ hidden_states = self.nonlinearity(hidden_states)
567
+
568
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1)
569
+ hidden_states = self.proj_out(hidden_states)
570
+ hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
571
+
572
+ return hidden_states, new_conv_cache
573
+
574
+
575
+ class MochiDecoder3D(nn.Module):
576
+ r"""
577
+ The `MochiDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
578
+ sample.
579
+
580
+ Args:
581
+ in_channels (`int`, *optional*):
582
+ The number of input channels.
583
+ out_channels (`int`, *optional*):
584
+ The number of output channels.
585
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
586
+ The number of output channels for each block.
587
+ layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
588
+ The number of resnet blocks for each block.
589
+ temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
590
+ The temporal expansion factor for each of the up blocks.
591
+ spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
592
+ The spatial expansion factor for each of the up blocks.
593
+ non_linearity (`str`, *optional*, defaults to `"swish"`):
594
+ The non-linearity to use in the decoder.
595
+ """
596
+
597
+ def __init__(
598
+ self,
599
+ in_channels: int, # 12
600
+ out_channels: int, # 3
601
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
602
+ layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
603
+ temporal_expansions: Tuple[int, ...] = (1, 2, 3),
604
+ spatial_expansions: Tuple[int, ...] = (2, 2, 2),
605
+ act_fn: str = "swish",
606
+ ):
607
+ super().__init__()
608
+
609
+ self.nonlinearity = get_activation(act_fn)
610
+
611
+ self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1))
612
+ self.block_in = MochiMidBlock3D(
613
+ in_channels=block_out_channels[-1],
614
+ num_layers=layers_per_block[-1],
615
+ add_attention=False,
616
+ )
617
+
618
+ up_blocks = []
619
+ for i in range(len(block_out_channels) - 1):
620
+ up_block = MochiUpBlock3D(
621
+ in_channels=block_out_channels[-i - 1],
622
+ out_channels=block_out_channels[-i - 2],
623
+ num_layers=layers_per_block[-i - 2],
624
+ temporal_expansion=temporal_expansions[-i - 1],
625
+ spatial_expansion=spatial_expansions[-i - 1],
626
+ )
627
+ up_blocks.append(up_block)
628
+ self.up_blocks = nn.ModuleList(up_blocks)
629
+
630
+ self.block_out = MochiMidBlock3D(
631
+ in_channels=block_out_channels[0],
632
+ num_layers=layers_per_block[0],
633
+ add_attention=False,
634
+ )
635
+ self.proj_out = nn.Linear(block_out_channels[0], out_channels)
636
+
637
+ self.gradient_checkpointing = False
638
+
639
+ def forward(
640
+ self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
641
+ ) -> torch.Tensor:
642
+ r"""Forward method of the `MochiDecoder3D` class."""
643
+
644
+ new_conv_cache = {}
645
+ conv_cache = conv_cache or {}
646
+
647
+ hidden_states = self.conv_in(hidden_states)
648
+
649
+ # 1. Mid
650
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
651
+
652
+ def create_custom_forward(module):
653
+ def create_forward(*inputs):
654
+ return module(*inputs)
655
+
656
+ return create_forward
657
+
658
+ hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint(
659
+ create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in")
660
+ )
661
+
662
+ for i, up_block in enumerate(self.up_blocks):
663
+ conv_cache_key = f"up_block_{i}"
664
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
665
+ create_custom_forward(up_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key)
666
+ )
667
+ else:
668
+ hidden_states, new_conv_cache["block_in"] = self.block_in(
669
+ hidden_states, conv_cache=conv_cache.get("block_in")
670
+ )
671
+
672
+ for i, up_block in enumerate(self.up_blocks):
673
+ conv_cache_key = f"up_block_{i}"
674
+ hidden_states, new_conv_cache[conv_cache_key] = up_block(
675
+ hidden_states, conv_cache=conv_cache.get(conv_cache_key)
676
+ )
677
+
678
+ hidden_states, new_conv_cache["block_out"] = self.block_out(
679
+ hidden_states, conv_cache=conv_cache.get("block_out")
680
+ )
681
+
682
+ hidden_states = self.nonlinearity(hidden_states)
683
+
684
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1)
685
+ hidden_states = self.proj_out(hidden_states)
686
+ hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
687
+
688
+ return hidden_states, new_conv_cache
689
+
690
+
691
+ class AutoencoderKLMochi(ModelMixin, ConfigMixin):
692
+ r"""
693
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
694
+ [Mochi 1 preview](https://github.com/genmoai/models).
695
+
696
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
697
+ for all models (such as downloading or saving).
698
+
699
+ Parameters:
700
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
701
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
702
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
703
+ Tuple of block output channels.
704
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
705
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
706
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
707
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
708
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
709
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
710
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
711
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
712
+ """
713
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["MochiResnetBlock3D"]
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 15,
+        out_channels: int = 3,
+        encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384),
+        decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768),
+        latent_channels: int = 12,
+        layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
+        act_fn: str = "silu",
+        temporal_expansions: Tuple[int, ...] = (1, 2, 3),
+        spatial_expansions: Tuple[int, ...] = (2, 2, 2),
+        add_attention_block: Tuple[bool, ...] = (False, True, True, True, True),
+        latents_mean: Tuple[float, ...] = (
+            -0.06730895953510081,
+            -0.038011381506090416,
+            -0.07477820912866141,
+            -0.05565264470995561,
+            0.012767231469026969,
+            -0.04703542746246419,
+            0.043896967884726704,
+            -0.09346305707025976,
+            -0.09918314763016893,
+            -0.008729793427399178,
+            -0.011931556316503654,
+            -0.0321993391887285,
+        ),
+        latents_std: Tuple[float, ...] = (
+            0.9263795028493863,
+            0.9248894543193766,
+            0.9393059390890617,
+            0.959253732819592,
+            0.8244560132752793,
+            0.917259975397747,
+            0.9294154431013696,
+            1.3720942357788521,
+            0.881393668867029,
+            0.9168315692124348,
+            0.9185249279345552,
+            0.9274757570805041,
+        ),
+        scaling_factor: float = 1.0,
+    ):
+        super().__init__()
+
+        self.encoder = MochiEncoder3D(
+            in_channels=in_channels,
+            out_channels=latent_channels,
+            block_out_channels=encoder_block_out_channels,
+            layers_per_block=layers_per_block,
+            temporal_expansions=temporal_expansions,
+            spatial_expansions=spatial_expansions,
+            add_attention_block=add_attention_block,
+            act_fn=act_fn,
+        )
+        self.decoder = MochiDecoder3D(
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            block_out_channels=decoder_block_out_channels,
+            layers_per_block=layers_per_block,
+            temporal_expansions=temporal_expansions,
+            spatial_expansions=spatial_expansions,
+            act_fn=act_fn,
+        )
+
+        self.spatial_compression_ratio = functools.reduce(lambda x, y: x * y, spatial_expansions, 1)
+        self.temporal_compression_ratio = functools.reduce(lambda x, y: x * y, temporal_expansions, 1)
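+        # With the defaults above these products evaluate to 8 (spatial: 2 * 2 * 2) and 6 (temporal: 1 * 2 * 3).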
+
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
+        self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
+        self.use_tiling = False
+
+        # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
+        # at a fixed frame batch size (based on `self.num_latent_frames_batch_size`), the memory requirement can be lowered.
+        self.use_framewise_encoding = False
+        self.use_framewise_decoding = False
+
+        # This determines the number of output frames in the final decoded video. To maintain consistency with
+        # the original implementation, this defaults to `True`.
+        # - Original implementation (drop_last_temporal_frames=True):
+        #   Output frames = (latent_frames - 1) * temporal_compression_ratio + 1
+        # - Without dropping additional temporal upscaled frames (drop_last_temporal_frames=False):
+        #   Output frames = latent_frames * temporal_compression_ratio
+        # The latter case is useful for frame packing and for some training/finetuning scenarios where the additional
+        # temporally upsampled frames are needed.
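+        # For example, with the default temporal_compression_ratio of 6, a 28-frame latent decodes to
+        # (28 - 1) * 6 + 1 = 163 frames when dropping, or 28 * 6 = 168 frames when not.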
+        self.drop_last_temporal_frames = True
+
+        # This can be configured based on the amount of GPU memory available.
+        # `12` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs.
+        # Setting these to higher values results in higher memory usage.
+        self.num_sample_frames_batch_size = 12
+        self.num_latent_frames_batch_size = 2
+
+        # The minimal tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
+
+        # The minimal distance between two spatial tiles
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (MochiEncoder3D, MochiDecoder3D)):
+            module.gradient_checkpointing = value
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[float] = None,
+        tile_sample_stride_width: Optional[float] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
+        processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
+                artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
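+    # For example, `vae.enable_tiling()` keeps the 256x256 tile size and 192-pixel stride defaults, so neighbouring
+    # tiles overlap by 64 pixels and are blended linearly over that band (see `blend_v`/`blend_h` below).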
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor into slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
+    def _enable_framewise_encoding(self):
+        r"""
+        Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the
+        oneshot encoding implementation without current latent replicate padding.
+
+        Warning: Framewise encoding may not work as expected due to the causal attention layers. If you enable
+        framewise encoding, encode a video, and try to decode it, there will be a noticeable jittering effect.
+        """
+        self.use_framewise_encoding = True
+        for name, module in self.named_modules():
+            if isinstance(module, CogVideoXCausalConv3d):
+                module.pad_mode = "constant"
+
+    def _enable_framewise_decoding(self):
+        r"""
+        Enables the framewise VAE decoding implementation with past latent padding. By default, Diffusers uses the
+        oneshot decoding implementation without current latent replicate padding.
+        """
+        self.use_framewise_decoding = True
+        for name, module in self.named_modules():
+            if isinstance(module, CogVideoXCausalConv3d):
+                module.pad_mode = "constant"
+
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
+        if self.use_framewise_encoding:
+            raise NotImplementedError(
+                "Frame-wise encoding does not work with the Mochi VAE Encoder due to the presence of attention layers. "
+                "As intermediate frames are not independent from each other, they cannot be encoded frame-wise."
+            )
+        else:
+            enc, _ = self.encoder(x)
+
+        return enc
+
+    @apply_forward_hook
+    def encode(
+        self, x: torch.Tensor, return_dict: bool = True
+    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+        """
+        Encode a batch of images into latents.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+        Returns:
+            The latent representations of the encoded videos. If `return_dict` is True, a
+            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+        """
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+
+        posterior = DiagonalGaussianDistribution(h)
+
+        if not return_dict:
+            return (posterior,)
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        batch_size, num_channels, num_frames, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=return_dict)
+
+        if self.use_framewise_decoding:
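+            # Decode `num_latent_frames_batch_size` latent frames at a time; the `conv_cache` returned by the
+            # decoder is fed back in so the causal convolutions see the previous chunk and stay temporally consistent.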
+            conv_cache = None
+            dec = []
+
+            for i in range(0, num_frames, self.num_latent_frames_batch_size):
+                z_intermediate = z[:, :, i : i + self.num_latent_frames_batch_size]
+                z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
+                dec.append(z_intermediate)
+
+            dec = torch.cat(dec, dim=2)
+        else:
+            dec, _ = self.decoder(z)
+
+        if self.drop_last_temporal_frames and dec.size(2) >= self.temporal_compression_ratio:
+            dec = dec[:, :, self.temporal_compression_ratio - 1 :]
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    @apply_forward_hook
+    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        """
+        Decode a batch of images.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
+        if not return_dict:
+            return (decoded,)
+
+        return DecoderOutput(sample=decoded)
+
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
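+    # Both blend helpers cross-fade linearly over the overlap: with blend_extent = 4, the weights applied to `a`
+    # across the seam are 1.0, 0.75, 0.5, 0.25 while `b` contributes 0.0, 0.25, 0.5, 0.75.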
+    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        batch_size, num_channels, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
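+        # With the 256/192 sample-space defaults and 8x spatial compression this gives tile_latent_min = 32 and
+        # tile_latent_stride = 24, so adjacent tiles share blend_height = blend_width = 8 latent rows/columns.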
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                if self.use_framewise_encoding:
+                    raise NotImplementedError(
+                        "Frame-wise encoding does not work with the Mochi VAE Encoder due to the presence of attention layers. "
+                        "As intermediate frames are not independent from each other, they cannot be encoded frame-wise."
+                    )
+                else:
+                    time, _ = self.encoder(
+                        x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    )
+
+                row.append(time)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=4))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of images using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+
+        batch_size, num_channels, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
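+        # Unlike `tiled_encode`, the blend extents here are in sample space: with the defaults, decoded tiles are
+        # 256x256 pixels, placed every 192 pixels, and blended over the 64-pixel overlap before cropping to the stride.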
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                if self.use_framewise_decoding:
+                    time = []
+                    conv_cache = None
+
+                    for k in range(0, num_frames, self.num_latent_frames_batch_size):
+                        tile = z[
+                            :,
+                            :,
+                            k : k + self.num_latent_frames_batch_size,
+                            i : i + tile_latent_min_height,
+                            j : j + tile_latent_min_width,
+                        ]
+                        tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
+                        time.append(tile)
+
+                    time = torch.cat(time, dim=2)
+                else:
+                    time, _ = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width])
+
+                if self.drop_last_temporal_frames and time.size(2) >= self.temporal_compression_ratio:
+                    time = time[:, :, self.temporal_compression_ratio - 1 :]
+
+                row.append(time)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=4))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        sample_posterior: bool = False,
+        return_dict: bool = True,
+        generator: Optional[torch.Generator] = None,
+    ) -> Union[DecoderOutput, Tuple[DecoderOutput]]:
+        x = sample
+        posterior = self.encode(x).latent_dist
+        if sample_posterior:
+            z = posterior.sample(generator=generator)
+        else:
+            z = posterior.mode()
+        dec = self.decode(z)
+        if not return_dict:
+            return (dec,)
+        return dec
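Taken together, slicing, tiling, and framewise decoding are the knobs a caller combines to trade compute for peak memory. A minimal decoding sketch follows; the checkpoint id and the random latent shape are illustrative assumptions, not something this diff specifies:

import torch
from diffusers import AutoencoderKLMochi

# Illustrative checkpoint id; any repository that stores AutoencoderKLMochi weights in a "vae" subfolder would work.
vae = AutoencoderKLMochi.from_pretrained("genmo/mochi-1-preview", subfolder="vae", torch_dtype=torch.float32)
vae.enable_slicing()  # decode one video latent at a time across the batch dimension
vae.enable_tiling()   # decode spatially in 256x256-pixel tiles with a 192-pixel stride

# Hypothetical latent: 1 video, 12 channels, 28 latent frames, 60x106 latent pixels. Real pipeline latents would
# typically be de-normalized with the latents_mean/latents_std statistics registered in the config above.
latents = torch.randn(1, 12, 28, 60, 106)
with torch.no_grad():
    video = vae.decode(latents).sample
print(video.shape)  # (1, 3, (28 - 1) * 6 + 1 = 163, 480, 848) with drop_last_temporal_frames=True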