diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
The detailed hunks below cover `diffusers/models/autoencoders/autoencoder_kl_cogvideox.py` (+213 −105):

```diff
--- a/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
```
```diff
@@ -41,7 +41,9 @@ class CogVideoXSafeConv3d(nn.Conv3d):
     """
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
+        memory_count = (
+            (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
+        )
 
         # Set to 2GB, suitable for CuDNN
         if memory_count > 2:
```
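The rewritten estimate makes the units explicit: element count × 2 bytes per element (fp16) ÷ 1024³ gives GiB, and anything over 2 GiB is split into parts before reaching cuDNN. A quick check of the arithmetic with an illustrative shape (not taken from the library):

```python
# Illustrative only: memory estimate for a hypothetical (B, C, F, H, W) activation,
# mirroring the rewritten expression above (2 bytes/element, GiB units).
shape = (1, 128, 49, 240, 360)
memory_gib = (shape[0] * shape[1] * shape[2] * shape[3] * shape[4]) * 2 / 1024**3
print(round(memory_gib, 2))  # ~1.01 GiB -> under the 2 GiB threshold, no split needed
```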
```diff
@@ -92,11 +94,13 @@ class CogVideoXCausalConv3d(nn.Module):
 
         time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
 
-        self.pad_mode = pad_mode
-        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
-        height_pad = height_kernel_size // 2
-        width_pad = width_kernel_size // 2
+        # TODO(aryan): configure calculation based on stride and dilation in the future.
+        # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
+        time_pad = time_kernel_size - 1
+        height_pad = (height_kernel_size - 1) // 2
+        width_pad = (width_kernel_size - 1) // 2
 
+        self.pad_mode = pad_mode
         self.height_pad = height_pad
         self.width_pad = width_pad
         self.time_pad = time_pad
@@ -105,7 +109,7 @@ class CogVideoXCausalConv3d(nn.Module):
         self.temporal_dim = 2
         self.time_kernel_size = time_kernel_size
 
-        stride = (stride, 1, 1)
+        stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
         dilation = (dilation, 1, 1)
         self.conv = CogVideoXSafeConv3d(
             in_channels=in_channels,
@@ -115,34 +119,30 @@ class CogVideoXCausalConv3d(nn.Module):
             dilation=dilation,
         )
 
-        self.conv_cache = None
-
-    def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        kernel_size = self.time_kernel_size
-        if kernel_size > 1:
-            cached_inputs = (
-                [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
-            )
-            inputs = torch.cat(cached_inputs + [inputs], dim=2)
+    def fake_context_parallel_forward(
+        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        if self.pad_mode == "replicate":
+            inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
+        else:
+            kernel_size = self.time_kernel_size
+            if kernel_size > 1:
+                cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
+                inputs = torch.cat(cached_inputs + [inputs], dim=2)
         return inputs
 
-    def _clear_fake_context_parallel_cache(self):
-        del self.conv_cache
-        self.conv_cache = None
+    def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
+        inputs = self.fake_context_parallel_forward(inputs, conv_cache)
 
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        inputs = self.fake_context_parallel_forward(inputs)
-
-        self._clear_fake_context_parallel_cache()
-        # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
-        # hundred megabytes and so let's not do it for now
-        self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
-
-        padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
-        inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
+        if self.pad_mode == "replicate":
+            conv_cache = None
+        else:
+            padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
+            conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+            inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
 
         output = self.conv(inputs)
-        return output
+        return output, conv_cache
 
 
 class CogVideoXSpatialNorm3D(nn.Module):
```
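The net effect of these hunks: the causal conv no longer hides its temporal state in `self.conv_cache`; the caller owns the cache, passing it in and receiving the updated one on every call. A minimal self-contained sketch of that pattern (a toy layer, not the library class; assumes a time kernel size of at least 2):

```python
import torch
import torch.nn as nn

class TinyCausalConv3d(nn.Module):
    """Toy causal temporal conv that returns (output, cache) like the diff above."""

    def __init__(self, channels: int, time_kernel_size: int = 3):
        super().__init__()
        self.k = time_kernel_size
        self.conv = nn.Conv3d(channels, channels, (time_kernel_size, 1, 1))

    def forward(self, x, conv_cache=None):
        # Causal padding: reuse cached trailing frames, else replicate the first frame.
        front = conv_cache if conv_cache is not None else x[:, :, :1].repeat(1, 1, self.k - 1, 1, 1)
        x = torch.cat([front, x], dim=2)
        new_cache = x[:, :, -(self.k - 1) :].clone()  # trailing frames seed the next chunk
        return self.conv(x), new_cache

layer = TinyCausalConv3d(4)
video = torch.randn(1, 4, 8, 2, 2)  # (B, C, F, H, W)

full, _ = layer(video)  # one pass over the whole clip ...

out1, cache = layer(video[:, :, :4])                # ... equals two chunked passes
out2, _ = layer(video[:, :, 4:], conv_cache=cache)  # with the cache threaded through
assert torch.allclose(full, torch.cat([out1, out2], dim=2), atol=1e-6)
```

Making the cache a return value instead of module state is what lets the VAE run the same layer over successive frame batches without leaking state between unrelated calls.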
```diff
@@ -172,7 +172,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
         self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
 
-    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
+    ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         if f.shape[2] > 1 and f.shape[2] % 2 == 1:
             f_first, f_rest = f[:, :, :1], f[:, :, 1:]
             f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
@@ -183,9 +188,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         else:
             zq = F.interpolate(zq, size=f.shape[-3:])
 
+        conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
+        conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
+
         norm_f = self.norm_layer(f)
-        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
-        return new_f
+        new_f = norm_f * conv_y + conv_b
+        return new_f, new_conv_cache
 
 
 class CogVideoXResnetBlock3D(nn.Module):
```
```diff
@@ -236,6 +244,7 @@ class CogVideoXResnetBlock3D(nn.Module):
         self.out_channels = out_channels
         self.nonlinearity = get_activation(non_linearity)
         self.use_conv_shortcut = conv_shortcut
+        self.spatial_norm_dim = spatial_norm_dim
 
         if spatial_norm_dim is None:
             self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
@@ -279,34 +288,43 @@ class CogVideoXResnetBlock3D(nn.Module):
         inputs: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         hidden_states = inputs
 
         if zq is not None:
-            hidden_states = self.norm1(hidden_states, zq)
+            hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
         else:
             hidden_states = self.norm1(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
-        hidden_states = self.conv1(hidden_states)
+        hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
 
         if temb is not None:
             hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
 
         if zq is not None:
-            hidden_states = self.norm2(hidden_states, zq)
+            hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
         else:
             hidden_states = self.norm2(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states)
+        hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
 
         if self.in_channels != self.out_channels:
-            inputs = self.conv_shortcut(inputs)
+            if self.use_conv_shortcut:
+                inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
+                    inputs, conv_cache=conv_cache.get("conv_shortcut")
+                )
+            else:
+                inputs = self.conv_shortcut(inputs)
 
         hidden_states = hidden_states + inputs
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXDownBlock3D(nn.Module):
```
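The blocks repeat the same contract one level up: every forward returns `(output, new_conv_cache)` and files each child's cache under a stable string key (`"norm1"`, `"conv1"`, `f"resnet_{i}"`), so the whole VAE's temporal state round-trips as one nested dict. A toy sketch of the convention (stand-in classes, not the library code):

```python
import torch

class ToyConv:
    def __call__(self, x, conv_cache=None):
        return x + 1, x[:, -1:].clone()  # (output, pretend trailing state)

class ToyBlock:
    """Parent module that namespaces its children's caches by string key."""

    def __init__(self):
        self.conv1, self.conv2 = ToyConv(), ToyConv()

    def __call__(self, x, conv_cache=None):
        new_conv_cache = {}
        conv_cache = conv_cache or {}
        x, new_conv_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1"))
        x, new_conv_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2"))
        return x, new_conv_cache

block, cache = ToyBlock(), None
for chunk in torch.randn(4, 2, 3).split(2):      # chunked "frames"
    out, cache = block(chunk, conv_cache=cache)  # cache == {"conv1": ..., "conv2": ...}
```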
```diff
@@ -392,9 +410,17 @@ class CogVideoXDownBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+        r"""Forward method of the `CogVideoXDownBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -402,17 +428,23 @@ class CogVideoXDownBlock3D(nn.Module):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXMidBlock3D(nn.Module):
```
```diff
@@ -480,9 +512,17 @@ class CogVideoXMidBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+        r"""Forward method of the `CogVideoXMidBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -490,13 +530,15 @@ class CogVideoXMidBlock3D(nn.Module):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXUpBlock3D(nn.Module):
```
```diff
@@ -584,10 +626,17 @@ class CogVideoXUpBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
         r"""Forward method of the `CogVideoXUpBlock3D` class."""
-        for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -595,17 +644,23 @@ class CogVideoXUpBlock3D(nn.Module):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
                 hidden_states = upsampler(hidden_states)
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXEncoder3D(nn.Module):
```
```diff
@@ -705,11 +760,20 @@ class CogVideoXEncoder3D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXEncoder3D` class."""
-        hidden_states = self.conv_in(sample)
 
-        if self.training and self.gradient_checkpointing:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
+
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
 
             def create_custom_forward(module):
                 def custom_forward(*inputs):
@@ -718,28 +782,44 @@ class CogVideoXEncoder3D(nn.Module):
                 return custom_forward
 
             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block), hidden_states, temb, None
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(down_block),
+                    hidden_states,
+                    temb,
+                    None,
+                    conv_cache.get(conv_cache_key),
                 )
 
             # 2. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, None
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                None,
+                conv_cache.get("mid_block"),
             )
         else:
             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = down_block(hidden_states, temb, None)
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = down_block(
+                    hidden_states, temb, None, conv_cache.get(conv_cache_key)
+                )
 
             # 2. Mid
-            hidden_states = self.mid_block(hidden_states, temb, None)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
+            )
 
         # 3. Post-process
         hidden_states = self.norm_out(hidden_states)
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-        return hidden_states
+
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXDecoder3D(nn.Module):
```
```diff
@@ -846,11 +926,20 @@ class CogVideoXDecoder3D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXDecoder3D` class."""
-        hidden_states = self.conv_in(sample)
 
-        if self.training and self.gradient_checkpointing:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
+
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
 
             def create_custom_forward(module):
                 def custom_forward(*inputs):
@@ -859,28 +948,45 @@ class CogVideoXDecoder3D(nn.Module):
                 return custom_forward
 
             # 1. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, sample
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                sample,
+                conv_cache.get("mid_block"),
             )
 
             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block), hidden_states, temb, sample
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(up_block),
+                    hidden_states,
+                    temb,
+                    sample,
+                    conv_cache.get(conv_cache_key),
                 )
         else:
             # 1. Mid
-            hidden_states = self.mid_block(hidden_states, temb, sample)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
+            )
 
             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = up_block(hidden_states, temb, sample)
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = up_block(
+                    hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         # 3. Post-process
-        hidden_states = self.norm_out(hidden_states, sample)
+        hidden_states, new_conv_cache["norm_out"] = self.norm_out(
+            hidden_states, sample, conv_cache=conv_cache.get("norm_out")
+        )
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-        return hidden_states
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache
 
 
 class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
```
```diff
@@ -951,6 +1057,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         force_upcast: float = True,
         use_quant_conv: bool = False,
         use_post_quant_conv: bool = False,
+        invert_scale_latents: bool = False,
     ):
         super().__init__()
 
@@ -1019,12 +1126,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
             module.gradient_checkpointing = value
 
-    def _clear_fake_context_parallel_cache(self):
-        for name, module in self.named_modules():
-            if isinstance(module, CogVideoXCausalConv3d):
-                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
-                module._clear_fake_context_parallel_cache()
-
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
```
```diff
@@ -1090,21 +1191,22 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
         frame_batch_size = self.num_sample_frames_batch_size
         # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-        num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+        # As the extra single frame is handled inside the loop, it is not required to round up here.
+        num_batches = max(num_frames // frame_batch_size, 1)
+        conv_cache = None
         enc = []
+
         for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
             end_frame = frame_batch_size * (i + 1) + remaining_frames
             x_intermediate = x[:, :, start_frame:end_frame]
-            x_intermediate = self.encoder(x_intermediate)
+            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
             if self.quant_conv is not None:
                 x_intermediate = self.quant_conv(x_intermediate)
             enc.append(x_intermediate)
 
-        self._clear_fake_context_parallel_cache()
         enc = torch.cat(enc, dim=2)
-
         return enc
 
     @apply_forward_hook
```
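The slicing arithmetic is worth a second look: the first batch absorbs the `num_frames % frame_batch_size` leftover (CogVideoX clips carry one extra frame, F = k·batch + 1), and `max(..., 1)` keeps single-frame inputs from yielding zero batches. A worked example with illustrative values:

```python
# Illustrative values: 49 frames in sample batches of 8 (F = 8k + 1, CogVideoX-style).
num_frames, frame_batch_size = 49, 8
num_batches = max(num_frames // frame_batch_size, 1)  # 6
for i in range(num_batches):
    remaining_frames = num_frames % frame_batch_size  # 1
    start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
    end_frame = frame_batch_size * (i + 1) + remaining_frames
    print(start_frame, end_frame)
# 0 9  <- the first batch takes the extra frame (9 frames)
# 9 17, 17 25, 25 33, 33 41, 41 49  <- clean batches of 8 after that
```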
```diff
@@ -1142,8 +1244,10 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             return self.tiled_decode(z, return_dict=return_dict)
 
         frame_batch_size = self.num_latent_frames_batch_size
-        num_batches = num_frames // frame_batch_size
+        num_batches = max(num_frames // frame_batch_size, 1)
+        conv_cache = None
         dec = []
+
         for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
@@ -1151,10 +1255,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             z_intermediate = z[:, :, start_frame:end_frame]
             if self.post_quant_conv is not None:
                 z_intermediate = self.post_quant_conv(z_intermediate)
-            z_intermediate = self.decoder(z_intermediate)
+            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
             dec.append(z_intermediate)
 
-        self._clear_fake_context_parallel_cache()
         dec = torch.cat(dec, dim=2)
 
         if not return_dict:
```
```diff
@@ -1237,8 +1340,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             row = []
             for j in range(0, width, overlap_width):
                 # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-                num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+                # As the extra single frame is handled inside the loop, it is not required to round up here.
+                num_batches = max(num_frames // frame_batch_size, 1)
+                conv_cache = None
                 time = []
+
                 for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1250,11 +1356,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                         i : i + self.tile_sample_min_height,
                         j : j + self.tile_sample_min_width,
                     ]
-                    tile = self.encoder(tile)
+                    tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
                     if self.quant_conv is not None:
                         tile = self.quant_conv(tile)
                     time.append(tile)
-                self._clear_fake_context_parallel_cache()
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
 
```
```diff
@@ -1314,8 +1420,10 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for i in range(0, height, overlap_height):
             row = []
             for j in range(0, width, overlap_width):
-                num_batches = num_frames // frame_batch_size
+                num_batches = max(num_frames // frame_batch_size, 1)
+                conv_cache = None
                 time = []
+
                 for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1329,9 +1437,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                     ]
                     if self.post_quant_conv is not None:
                         tile = self.post_quant_conv(tile)
-                    tile = self.decoder(tile)
+                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                     time.append(tile)
-                self._clear_fake_context_parallel_cache()
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
 
```
```diff
@@ -1368,7 +1476,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             z = posterior.sample(generator=generator)
         else:
             z = posterior.mode()
-        dec = self.decode(z)
+        dec = self.decode(z).sample
         if not return_dict:
             return (dec,)
-        return dec
+        return DecoderOutput(sample=dec)
```