diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Optional, Tuple, Union
+ from typing import Dict, Optional, Tuple, Union

  import numpy as np
  import torch
@@ -41,7 +41,9 @@ class CogVideoXSafeConv3d(nn.Conv3d):
      """

      def forward(self, input: torch.Tensor) -> torch.Tensor:
-         memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
+         memory_count = (
+             (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
+         )

          # Set to 2GB, suitable for CuDNN
          if memory_count > 2:
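The rewrite above replaces `torch.prod(torch.tensor(input.shape)).item()` with a plain product of Python ints: the same estimate, without materializing a tensor just to multiply five scalars. A quick sanity check that the two expressions agree (the shape below is made up; the constant 2 assumes 2-byte fp16/bf16 elements):

```python
import torch

# Illustrative (batch, channels, frames, height, width); not a shape from the diff.
shape = (1, 16, 13, 60, 90)
x = torch.empty(shape, dtype=torch.float16)

old_gib = torch.prod(torch.tensor(x.shape)).item() * 2 / 1024**3
new_gib = (shape[0] * shape[1] * shape[2] * shape[3] * shape[4]) * 2 / 1024**3
assert old_gib == new_gib  # identical estimate, cheaper to compute
print(f"{new_gib:.4f} GiB")  # the chunked fallback in forward() kicks in above 2 GiB
```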
@@ -92,11 +94,13 @@ class CogVideoXCausalConv3d(nn.Module):

          time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

-         self.pad_mode = pad_mode
-         time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
-         height_pad = height_kernel_size // 2
-         width_pad = width_kernel_size // 2
+         # TODO(aryan): configure calculation based on stride and dilation in the future.
+         # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
+         time_pad = time_kernel_size - 1
+         height_pad = (height_kernel_size - 1) // 2
+         width_pad = (width_kernel_size - 1) // 2

+         self.pad_mode = pad_mode
          self.height_pad = height_pad
          self.width_pad = width_pad
          self.time_pad = time_pad
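Under the new scheme the temporal axis is strictly causal: all `time_kernel_size - 1` frames of padding sit on the past side, while the spatial pads stay centered. A worked example, with an illustrative `kernel_size = (3, 3, 3)` that is not taken from the diff:

```python
# Names mirror the diff; the kernel size is illustrative.
time_kernel_size, height_kernel_size, width_kernel_size = 3, 3, 3

time_pad = time_kernel_size - 1             # 2 frames of history, all on the past side
height_pad = (height_kernel_size - 1) // 2  # 1, centered vertically
width_pad = (width_kernel_size - 1) // 2    # 1, centered horizontally

# With stride 1, left-padding by (kernel - 1) preserves the frame count:
# T_out = T + time_pad - (time_kernel_size - 1) = T, so frame t never sees frames > t.
T = 13
assert T + time_pad - (time_kernel_size - 1) == T
```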
@@ -105,7 +109,7 @@ class CogVideoXCausalConv3d(nn.Module):
          self.temporal_dim = 2
          self.time_kernel_size = time_kernel_size

-         stride = (stride, 1, 1)
+         stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
          dilation = (dilation, 1, 1)
          self.conv = CogVideoXSafeConv3d(
              in_channels=in_channels,
@@ -115,34 +119,30 @@ class CogVideoXCausalConv3d(nn.Module):
              dilation=dilation,
          )

-         self.conv_cache = None
-
-     def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
-         kernel_size = self.time_kernel_size
-         if kernel_size > 1:
-             cached_inputs = (
-                 [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
-             )
-             inputs = torch.cat(cached_inputs + [inputs], dim=2)
+     def fake_context_parallel_forward(
+         self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
+     ) -> torch.Tensor:
+         if self.pad_mode == "replicate":
+             inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
+         else:
+             kernel_size = self.time_kernel_size
+             if kernel_size > 1:
+                 cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
+                 inputs = torch.cat(cached_inputs + [inputs], dim=2)
          return inputs

-     def _clear_fake_context_parallel_cache(self):
-         del self.conv_cache
-         self.conv_cache = None
+     def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
+         inputs = self.fake_context_parallel_forward(inputs, conv_cache)

-     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-         inputs = self.fake_context_parallel_forward(inputs)
-
-         self._clear_fake_context_parallel_cache()
-         # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
-         # hundred megabytes and so let's not do it for now
-         self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
-
-         padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
-         inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
+         if self.pad_mode == "replicate":
+             conv_cache = None
+         else:
+             padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
+             conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+             inputs = F.pad(inputs, padding_2d, mode="constant", value=0)

          output = self.conv(inputs)
-         return output
+         return output, conv_cache


  class CogVideoXSpatialNorm3D(nn.Module):
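This hunk is the heart of the refactor: the mutable `self.conv_cache` attribute is gone, and the trailing `time_kernel_size - 1` frames are instead passed into each `forward` call and returned alongside the output. A minimal standalone sketch of that idea (not the diffusers code), using a 1-D causal convolution to show that chunked processing with an explicit cache matches running the whole sequence at once:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight = torch.randn(1, 1, 3)  # (out_channels, in_channels, kernel_size=3)
x = torch.randn(1, 1, 16)      # (batch, channels, time)

def causal_chunk(chunk, cache):
    # First chunk: replicate the first sample as implicit history,
    # mirroring `[inputs[:, :, :1]] * (kernel_size - 1)` in the diff.
    if cache is None:
        cache = chunk[:, :, :1].repeat(1, 1, weight.shape[-1] - 1)
    padded = torch.cat([cache, chunk], dim=2)
    new_cache = padded[:, :, -(weight.shape[-1] - 1):].clone()  # carried to the next call
    return F.conv1d(padded, weight), new_cache

cache = None
outs = []
for chunk in x.split(4, dim=2):  # process in four 4-sample chunks
    out, cache = causal_chunk(chunk, cache)
    outs.append(out)

full, _ = causal_chunk(x, None)  # same op over the whole sequence
assert torch.allclose(torch.cat(outs, dim=2), full, atol=1e-6)
```

Because the cache is threaded functionally rather than stored on the module, nothing has to be cleared between videos; the caller simply starts the next sequence with `conv_cache=None`.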
@@ -172,7 +172,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
          self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
          self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)

-     def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+     def forward(
+         self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
+     ) -> torch.Tensor:
+         new_conv_cache = {}
+         conv_cache = conv_cache or {}
+
          if f.shape[2] > 1 and f.shape[2] % 2 == 1:
              f_first, f_rest = f[:, :, :1], f[:, :, 1:]
              f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
@@ -183,9 +188,12 @@
          else:
              zq = F.interpolate(zq, size=f.shape[-3:])

+         conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
+         conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
+
          norm_f = self.norm_layer(f)
-         new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
-         return new_f
+         new_f = norm_f * conv_y + conv_b
+         return new_f, new_conv_cache


  class CogVideoXResnetBlock3D(nn.Module):
@@ -236,6 +244,7 @@
          self.out_channels = out_channels
          self.nonlinearity = get_activation(non_linearity)
          self.use_conv_shortcut = conv_shortcut
+         self.spatial_norm_dim = spatial_norm_dim

          if spatial_norm_dim is None:
              self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
@@ -279,34 +288,43 @@
          inputs: torch.Tensor,
          temb: Optional[torch.Tensor] = None,
          zq: Optional[torch.Tensor] = None,
+         conv_cache: Optional[Dict[str, torch.Tensor]] = None,
      ) -> torch.Tensor:
+         new_conv_cache = {}
+         conv_cache = conv_cache or {}
+
          hidden_states = inputs

          if zq is not None:
-             hidden_states = self.norm1(hidden_states, zq)
+             hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
          else:
              hidden_states = self.norm1(hidden_states)

          hidden_states = self.nonlinearity(hidden_states)
-         hidden_states = self.conv1(hidden_states)
+         hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))

          if temb is not None:
              hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]

          if zq is not None:
-             hidden_states = self.norm2(hidden_states, zq)
+             hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
          else:
              hidden_states = self.norm2(hidden_states)

          hidden_states = self.nonlinearity(hidden_states)
          hidden_states = self.dropout(hidden_states)
-         hidden_states = self.conv2(hidden_states)
+         hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))

          if self.in_channels != self.out_channels:
-             inputs = self.conv_shortcut(inputs)
+             if self.use_conv_shortcut:
+                 inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
+                     inputs, conv_cache=conv_cache.get("conv_shortcut")
+                 )
+             else:
+                 inputs = self.conv_shortcut(inputs)

          hidden_states = hidden_states + inputs
-         return hidden_states
+         return hidden_states, new_conv_cache


  class CogVideoXDownBlock3D(nn.Module):
@@ -392,9 +410,17 @@
          hidden_states: torch.Tensor,
          temb: Optional[torch.Tensor] = None,
          zq: Optional[torch.Tensor] = None,
+         conv_cache: Optional[Dict[str, torch.Tensor]] = None,
      ) -> torch.Tensor:
-         for resnet in self.resnets:
-             if self.training and self.gradient_checkpointing:
+         r"""Forward method of the `CogVideoXDownBlock3D` class."""
+
+         new_conv_cache = {}
+         conv_cache = conv_cache or {}
+
+         for i, resnet in enumerate(self.resnets):
+             conv_cache_key = f"resnet_{i}"
+
+             if torch.is_grad_enabled() and self.gradient_checkpointing:

                  def create_custom_forward(module):
                      def create_forward(*inputs):
@@ -402,17 +428,23 @@

                      return create_forward

-                 hidden_states = torch.utils.checkpoint.checkpoint(
-                     create_custom_forward(resnet), hidden_states, temb, zq
+                 hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(resnet),
+                     hidden_states,
+                     temb,
+                     zq,
+                     conv_cache.get(conv_cache_key),
                  )
              else:
-                 hidden_states = resnet(hidden_states, temb, zq)
+                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                     hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                 )

          if self.downsamplers is not None:
              for downsampler in self.downsamplers:
                  hidden_states = downsampler(hidden_states)

-         return hidden_states
+         return hidden_states, new_conv_cache


  class CogVideoXMidBlock3D(nn.Module):
@@ -480,9 +512,17 @@
          hidden_states: torch.Tensor,
          temb: Optional[torch.Tensor] = None,
          zq: Optional[torch.Tensor] = None,
+         conv_cache: Optional[Dict[str, torch.Tensor]] = None,
      ) -> torch.Tensor:
-         for resnet in self.resnets:
-             if self.training and self.gradient_checkpointing:
+         r"""Forward method of the `CogVideoXMidBlock3D` class."""
+
+         new_conv_cache = {}
+         conv_cache = conv_cache or {}
+
+         for i, resnet in enumerate(self.resnets):
+             conv_cache_key = f"resnet_{i}"
+
+             if torch.is_grad_enabled() and self.gradient_checkpointing:

                  def create_custom_forward(module):
                      def create_forward(*inputs):
@@ -490,13 +530,15 @@

                      return create_forward

-                 hidden_states = torch.utils.checkpoint.checkpoint(
-                     create_custom_forward(resnet), hidden_states, temb, zq
+                 hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
                  )
              else:
-                 hidden_states = resnet(hidden_states, temb, zq)
+                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                     hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                 )

-         return hidden_states
+         return hidden_states, new_conv_cache


  class CogVideoXUpBlock3D(nn.Module):
@@ -584,10 +626,17 @@
          hidden_states: torch.Tensor,
          temb: Optional[torch.Tensor] = None,
          zq: Optional[torch.Tensor] = None,
+         conv_cache: Optional[Dict[str, torch.Tensor]] = None,
      ) -> torch.Tensor:
          r"""Forward method of the `CogVideoXUpBlock3D` class."""
-         for resnet in self.resnets:
-             if self.training and self.gradient_checkpointing:
+
+         new_conv_cache = {}
+         conv_cache = conv_cache or {}
+
+         for i, resnet in enumerate(self.resnets):
+             conv_cache_key = f"resnet_{i}"
+
+             if torch.is_grad_enabled() and self.gradient_checkpointing:

                  def create_custom_forward(module):
                      def create_forward(*inputs):
@@ -595,17 +644,23 @@

                      return create_forward

-                 hidden_states = torch.utils.checkpoint.checkpoint(
-                     create_custom_forward(resnet), hidden_states, temb, zq
+                 hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(resnet),
+                     hidden_states,
+                     temb,
+                     zq,
+                     conv_cache.get(conv_cache_key),
                  )
              else:
-                 hidden_states = resnet(hidden_states, temb, zq)
+                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                     hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                 )

          if self.upsamplers is not None:
              for upsampler in self.upsamplers:
                  hidden_states = upsampler(hidden_states)

-         return hidden_states
+         return hidden_states, new_conv_cache


  class CogVideoXEncoder3D(nn.Module):
@@ -705,11 +760,20 @@

          self.gradient_checkpointing = False

-     def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+     def forward(
+         self,
+         sample: torch.Tensor,
+         temb: Optional[torch.Tensor] = None,
+         conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+     ) -> torch.Tensor:
          r"""The forward method of the `CogVideoXEncoder3D` class."""
-         hidden_states = self.conv_in(sample)

-         if self.training and self.gradient_checkpointing:
+         new_conv_cache = {}
+         conv_cache = conv_cache or {}
+
+         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
+
+         if torch.is_grad_enabled() and self.gradient_checkpointing:

              def create_custom_forward(module):
                  def custom_forward(*inputs):
@@ -718,28 +782,44 @@
                  return custom_forward

              # 1. Down
-             for down_block in self.down_blocks:
-                 hidden_states = torch.utils.checkpoint.checkpoint(
-                     create_custom_forward(down_block), hidden_states, temb, None
+             for i, down_block in enumerate(self.down_blocks):
+                 conv_cache_key = f"down_block_{i}"
+                 hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(down_block),
+                     hidden_states,
+                     temb,
+                     None,
+                     conv_cache.get(conv_cache_key),
                  )

              # 2. Mid
-             hidden_states = torch.utils.checkpoint.checkpoint(
-                 create_custom_forward(self.mid_block), hidden_states, temb, None
+             hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                 create_custom_forward(self.mid_block),
+                 hidden_states,
+                 temb,
+                 None,
+                 conv_cache.get("mid_block"),
              )
          else:
              # 1. Down
-             for down_block in self.down_blocks:
-                 hidden_states = down_block(hidden_states, temb, None)
+             for i, down_block in enumerate(self.down_blocks):
+                 conv_cache_key = f"down_block_{i}"
+                 hidden_states, new_conv_cache[conv_cache_key] = down_block(
+                     hidden_states, temb, None, conv_cache.get(conv_cache_key)
+                 )

              # 2. Mid
-             hidden_states = self.mid_block(hidden_states, temb, None)
+             hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                 hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
+             )

          # 3. Post-process
          hidden_states = self.norm_out(hidden_states)
          hidden_states = self.conv_act(hidden_states)
-         hidden_states = self.conv_out(hidden_states)
-         return hidden_states
+
+         hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+         return hidden_states, new_conv_cache


  class CogVideoXDecoder3D(nn.Module):
@@ -846,11 +926,20 @@

          self.gradient_checkpointing = False

-     def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+     def forward(
+         self,
+         sample: torch.Tensor,
+         temb: Optional[torch.Tensor] = None,
+         conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+     ) -> torch.Tensor:
          r"""The forward method of the `CogVideoXDecoder3D` class."""
-         hidden_states = self.conv_in(sample)

-         if self.training and self.gradient_checkpointing:
+         new_conv_cache = {}
+         conv_cache = conv_cache or {}
+
+         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
+
+         if torch.is_grad_enabled() and self.gradient_checkpointing:

              def create_custom_forward(module):
                  def custom_forward(*inputs):
@@ -859,28 +948,45 @@

                  return custom_forward

              # 1. Mid
-             hidden_states = torch.utils.checkpoint.checkpoint(
-                 create_custom_forward(self.mid_block), hidden_states, temb, sample
+             hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                 create_custom_forward(self.mid_block),
+                 hidden_states,
+                 temb,
+                 sample,
+                 conv_cache.get("mid_block"),
              )

              # 2. Up
-             for up_block in self.up_blocks:
-                 hidden_states = torch.utils.checkpoint.checkpoint(
-                     create_custom_forward(up_block), hidden_states, temb, sample
+             for i, up_block in enumerate(self.up_blocks):
+                 conv_cache_key = f"up_block_{i}"
+                 hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(up_block),
+                     hidden_states,
+                     temb,
+                     sample,
+                     conv_cache.get(conv_cache_key),
                  )
          else:
              # 1. Mid
-             hidden_states = self.mid_block(hidden_states, temb, sample)
+             hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                 hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
+             )

              # 2. Up
-             for up_block in self.up_blocks:
-                 hidden_states = up_block(hidden_states, temb, sample)
+             for i, up_block in enumerate(self.up_blocks):
+                 conv_cache_key = f"up_block_{i}"
+                 hidden_states, new_conv_cache[conv_cache_key] = up_block(
+                     hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
+                 )

          # 3. Post-process
-         hidden_states = self.norm_out(hidden_states, sample)
+         hidden_states, new_conv_cache["norm_out"] = self.norm_out(
+             hidden_states, sample, conv_cache=conv_cache.get("norm_out")
+         )
          hidden_states = self.conv_act(hidden_states)
-         hidden_states = self.conv_out(hidden_states)
-         return hidden_states
+         hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+         return hidden_states, new_conv_cache


  class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@@ -951,6 +1057,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
          force_upcast: float = True,
          use_quant_conv: bool = False,
          use_post_quant_conv: bool = False,
+         invert_scale_latents: bool = False,
      ):
          super().__init__()

@@ -1019,12 +1126,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
          if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
              module.gradient_checkpointing = value

-     def _clear_fake_context_parallel_cache(self):
-         for name, module in self.named_modules():
-             if isinstance(module, CogVideoXCausalConv3d):
-                 logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
-                 module._clear_fake_context_parallel_cache()
-
      def enable_tiling(
          self,
          tile_sample_min_height: Optional[int] = None,
@@ -1090,21 +1191,22 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):

          frame_batch_size = self.num_sample_frames_batch_size
          # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-         num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+         # As the extra single frame is handled inside the loop, it is not required to round up here.
+         num_batches = max(num_frames // frame_batch_size, 1)
+         conv_cache = None
          enc = []
+
          for i in range(num_batches):
              remaining_frames = num_frames % frame_batch_size
              start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
              end_frame = frame_batch_size * (i + 1) + remaining_frames
              x_intermediate = x[:, :, start_frame:end_frame]
-             x_intermediate = self.encoder(x_intermediate)
+             x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
              if self.quant_conv is not None:
                  x_intermediate = self.quant_conv(x_intermediate)
              enc.append(x_intermediate)

-         self._clear_fake_context_parallel_cache()
          enc = torch.cat(enc, dim=2)
-
          return enc

      @apply_forward_hook
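For reference, here are the slice indices the encode loop above produces, with illustrative values `num_frames = 49` and `frame_batch_size = 8` (the `frame_batch_size * k + 1` case the comment mentions; the first batch absorbs the extra frame):

```python
num_frames, frame_batch_size = 49, 8  # illustrative values, not from the diff

num_batches = max(num_frames // frame_batch_size, 1)  # 6; no rounding up needed
remaining_frames = num_frames % frame_batch_size      # 1

slices = []
for i in range(num_batches):
    start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
    end_frame = frame_batch_size * (i + 1) + remaining_frames
    slices.append((start_frame, end_frame))

print(slices)  # [(0, 9), (9, 17), (17, 25), (25, 33), (33, 41), (41, 49)]
assert slices[-1][1] == num_frames  # every frame is covered exactly once
```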
@@ -1142,8 +1244,10 @@
              return self.tiled_decode(z, return_dict=return_dict)

          frame_batch_size = self.num_latent_frames_batch_size
-         num_batches = num_frames // frame_batch_size
+         num_batches = max(num_frames // frame_batch_size, 1)
+         conv_cache = None
          dec = []
+
          for i in range(num_batches):
              remaining_frames = num_frames % frame_batch_size
              start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
@@ -1151,10 +1255,9 @@
              z_intermediate = z[:, :, start_frame:end_frame]
              if self.post_quant_conv is not None:
                  z_intermediate = self.post_quant_conv(z_intermediate)
-             z_intermediate = self.decoder(z_intermediate)
+             z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
              dec.append(z_intermediate)

-         self._clear_fake_context_parallel_cache()
          dec = torch.cat(dec, dim=2)

          if not return_dict:
@@ -1237,8 +1340,11 @@
              row = []
              for j in range(0, width, overlap_width):
                  # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-                 num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+                 # As the extra single frame is handled inside the loop, it is not required to round up here.
+                 num_batches = max(num_frames // frame_batch_size, 1)
+                 conv_cache = None
                  time = []
+
                  for k in range(num_batches):
                      remaining_frames = num_frames % frame_batch_size
                      start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1250,11 +1356,11 @@
                          i : i + self.tile_sample_min_height,
                          j : j + self.tile_sample_min_width,
                      ]
-                     tile = self.encoder(tile)
+                     tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
                      if self.quant_conv is not None:
                          tile = self.quant_conv(tile)
                      time.append(tile)
-                 self._clear_fake_context_parallel_cache()
+
                  row.append(torch.cat(time, dim=2))
              rows.append(row)

@@ -1314,8 +1420,10 @@
          for i in range(0, height, overlap_height):
              row = []
              for j in range(0, width, overlap_width):
-                 num_batches = num_frames // frame_batch_size
+                 num_batches = max(num_frames // frame_batch_size, 1)
+                 conv_cache = None
                  time = []
+
                  for k in range(num_batches):
                      remaining_frames = num_frames % frame_batch_size
                      start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1329,9 +1437,9 @@
                      ]
                      if self.post_quant_conv is not None:
                          tile = self.post_quant_conv(tile)
-                     tile = self.decoder(tile)
+                     tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                      time.append(tile)
-                 self._clear_fake_context_parallel_cache()
+
                  row.append(torch.cat(time, dim=2))
              rows.append(row)

@@ -1368,7 +1476,7 @@
              z = posterior.sample(generator=generator)
          else:
              z = posterior.mode()
-         dec = self.decode(z)
+         dec = self.decode(z).sample
          if not return_dict:
              return (dec,)
-         return dec
+         return DecoderOutput(sample=dec)
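The final hunk fixes `AutoencoderKLCogVideoX.forward` to unwrap the `DecoderOutput` returned by `decode()` before re-wrapping it, so callers consistently read `.sample`. A usage sketch of the updated autoencoder; the checkpoint id, device, and shapes are illustrative, not taken from this diff:

```python
import torch
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

# (batch, channels, frames, height, width); 9 frames matches the
# `frame_batch_size * k + 1` expectation noted in the diff's comments.
video = torch.randn(1, 3, 9, 256, 256, dtype=torch.float16, device="cuda")

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
    frames = vae.decode(latents).sample  # DecoderOutput -> tensor

print(frames.shape)
```

Internally, `encode`/`decode` now thread a `conv_cache` between frame batches (as shown in the hunks above) instead of clearing per-module state, so the chunked passes remain temporally consistent without any manual cache management by the caller.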