diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +38 -2
- diffusers/configuration_utils.py +12 -0
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +257 -54
- diffusers/loaders/__init__.py +2 -0
- diffusers/loaders/ip_adapter.py +5 -1
- diffusers/loaders/lora_base.py +14 -7
- diffusers/loaders/lora_conversion_utils.py +332 -0
- diffusers/loaders/lora_pipeline.py +707 -41
- diffusers/loaders/peft.py +1 -0
- diffusers/loaders/single_file_utils.py +81 -4
- diffusers/loaders/textual_inversion.py +2 -0
- diffusers/loaders/unet.py +39 -8
- diffusers/models/__init__.py +4 -0
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +86 -10
- diffusers/models/attention_processor.py +169 -133
- diffusers/models/autoencoders/autoencoder_kl.py +71 -11
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
- diffusers/models/controlnet_flux.py +536 -0
- diffusers/models/controlnet_sd3.py +7 -3
- diffusers/models/controlnet_sparsectrl.py +0 -1
- diffusers/models/embeddings.py +238 -61
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +182 -14
- diffusers/models/modeling_utils.py +283 -46
- diffusers/models/normalization.py +79 -0
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
- diffusers/models/transformers/pixart_transformer_2d.py +9 -1
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +161 -44
- diffusers/models/transformers/transformer_sd3.py +7 -1
- diffusers/models/unets/unet_2d_condition.py +8 -8
- diffusers/models/unets/unet_motion_model.py +41 -63
- diffusers/models/upsampling.py +6 -6
- diffusers/pipelines/__init__.py +40 -7
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
- diffusers/pipelines/auto_pipeline.py +39 -8
- diffusers/pipelines/cogvideo/__init__.py +6 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +10 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -20
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
- diffusers/pipelines/pag/__init__.py +6 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_loading_utils.py +225 -27
- diffusers/pipelines/pipeline_utils.py +123 -180
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +126 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/quantization_config.py +391 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +4 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
- diffusers/schedulers/scheduling_deis_multistep.py +78 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_sasolver.py +78 -1
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
- diffusers/training_utils.py +48 -18
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
- diffusers/utils/hub_utils.py +16 -4
- diffusers/utils/import_utils.py +31 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +3 -3
- diffusers/utils/testing_utils.py +59 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
diffusers/models/autoencoders/autoencoder_kl_cogvideox.py (0.30.2 → 0.31.0)

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

 import numpy as np
 import torch

@@ -41,7 +41,9 @@ class CogVideoXSafeConv3d(nn.Conv3d):
     """

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        memory_count =
+        memory_count = (
+            (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
+        )

         # Set to 2GB, suitable for CuDNN
         if memory_count > 2:

@@ -115,34 +117,24 @@ class CogVideoXCausalConv3d(nn.Module):
             dilation=dilation,
         )

-
-
-
+    def fake_context_parallel_forward(
+        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         kernel_size = self.time_kernel_size
         if kernel_size > 1:
-            cached_inputs = (
-                [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
-            )
+            cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
             inputs = torch.cat(cached_inputs + [inputs], dim=2)
         return inputs

-    def
-
-
-
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        inputs = self.fake_context_parallel_forward(inputs)
-
-        self._clear_fake_context_parallel_cache()
-        # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
-        # hundred megabytes and so let's not do it for now
-        self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+    def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
+        inputs = self.fake_context_parallel_forward(inputs, conv_cache)
+        conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()

         padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
         inputs = F.pad(inputs, padding_2d, mode="constant", value=0)

         output = self.conv(inputs)
-        return output
+        return output, conv_cache


 class CogVideoXSpatialNorm3D(nn.Module):
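The hunks above replace the layer's internal `self.conv_cache` state (together with the model-level cache clearing removed later in this diff) with an explicit `conv_cache` value that every forward call now takes and returns. A minimal sketch of that threading pattern, using a toy causal 1D convolution rather than the diffusers classes (all names here are illustrative, not library API): chunked processing with the cache passed between calls reproduces the full-sequence result.

from typing import Optional, Tuple

import torch
import torch.nn as nn


class ToyCausalConv1d(nn.Module):
    """Illustrative only: mirrors the explicit conv_cache in/out pattern shown above."""

    def __init__(self, channels: int, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(channels, channels, kernel_size)

    def forward(
        self, x: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if conv_cache is None:
            # No cache yet: pad causally by repeating the first frame.
            conv_cache = x[:, :, :1].repeat(1, 1, self.kernel_size - 1)
        x = torch.cat([conv_cache, x], dim=2)
        # The last (kernel_size - 1) input frames become the cache for the next chunk.
        new_cache = x[:, :, -(self.kernel_size - 1) :].clone()
        return self.conv(x), new_cache


torch.manual_seed(0)
layer = ToyCausalConv1d(channels=4, kernel_size=3)
x = torch.randn(1, 4, 8)

full, _ = layer(x)  # one pass over the whole sequence

conv_cache = None
chunks = []
for chunk in x.split(4, dim=2):  # two chunks of 4 frames
    out, conv_cache = layer(chunk, conv_cache)  # thread the cache explicitly
    chunks.append(out)

assert torch.allclose(full, torch.cat(chunks, dim=2), atol=1e-6)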
@@ -172,7 +164,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
         self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)

-    def forward(
+    def forward(
+        self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
+    ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         if f.shape[2] > 1 and f.shape[2] % 2 == 1:
             f_first, f_rest = f[:, :, :1], f[:, :, 1:]
             f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]

@@ -183,9 +180,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         else:
             zq = F.interpolate(zq, size=f.shape[-3:])

+        conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
+        conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
+
         norm_f = self.norm_layer(f)
-        new_f = norm_f *
-        return new_f
+        new_f = norm_f * conv_y + conv_b
+        return new_f, new_conv_cache


 class CogVideoXResnetBlock3D(nn.Module):

@@ -236,6 +236,7 @@ class CogVideoXResnetBlock3D(nn.Module):
         self.out_channels = out_channels
         self.nonlinearity = get_activation(non_linearity)
         self.use_conv_shortcut = conv_shortcut
+        self.spatial_norm_dim = spatial_norm_dim

         if spatial_norm_dim is None:
             self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)

@@ -279,34 +280,43 @@ class CogVideoXResnetBlock3D(nn.Module):
         inputs: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         hidden_states = inputs

         if zq is not None:
-            hidden_states = self.norm1(hidden_states, zq)
+            hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
         else:
             hidden_states = self.norm1(hidden_states)

         hidden_states = self.nonlinearity(hidden_states)
-        hidden_states = self.conv1(hidden_states)
+        hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))

         if temb is not None:
             hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]

         if zq is not None:
-            hidden_states = self.norm2(hidden_states, zq)
+            hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
         else:
             hidden_states = self.norm2(hidden_states)

         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states)
+        hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))

         if self.in_channels != self.out_channels:
-
+            if self.use_conv_shortcut:
+                inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
+                    inputs, conv_cache=conv_cache.get("conv_shortcut")
+                )
+            else:
+                inputs = self.conv_shortcut(inputs)

         hidden_states = hidden_states + inputs
-        return hidden_states
+        return hidden_states, new_conv_cache


 class CogVideoXDownBlock3D(nn.Module):

@@ -392,8 +402,16 @@ class CogVideoXDownBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-
+        r"""Forward method of the `CogVideoXDownBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:

                 def create_custom_forward(module):

@@ -402,17 +420,23 @@ class CogVideoXDownBlock3D(nn.Module):

                    return create_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)

-        return hidden_states
+        return hidden_states, new_conv_cache


 class CogVideoXMidBlock3D(nn.Module):

@@ -480,8 +504,16 @@ class CogVideoXMidBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-
+        r"""Forward method of the `CogVideoXMidBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:

                 def create_custom_forward(module):

@@ -490,13 +522,15 @@ class CogVideoXMidBlock3D(nn.Module):

                    return create_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                 )
             else:
-                hidden_states = resnet(
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )

-        return hidden_states
+        return hidden_states, new_conv_cache


 class CogVideoXUpBlock3D(nn.Module):

@@ -584,9 +618,16 @@ class CogVideoXUpBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
         r"""Forward method of the `CogVideoXUpBlock3D` class."""
-
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:

                 def create_custom_forward(module):

@@ -595,17 +636,23 @@ class CogVideoXUpBlock3D(nn.Module):

                    return create_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
                 hidden_states = upsampler(hidden_states)

-        return hidden_states
+        return hidden_states, new_conv_cache


 class CogVideoXEncoder3D(nn.Module):
@@ -705,9 +752,18 @@ class CogVideoXEncoder3D(nn.Module):

         self.gradient_checkpointing = False

-    def forward(
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXEncoder3D` class."""
-
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

         if self.training and self.gradient_checkpointing:

@@ -718,28 +774,44 @@ class CogVideoXEncoder3D(nn.Module):
                 return custom_forward

             # 1. Down
-            for down_block in self.down_blocks:
-
-
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(down_block),
+                    hidden_states,
+                    temb,
+                    None,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )

             # 2. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block),
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                None,
+                conv_cache=conv_cache.get("mid_block"),
             )
         else:
             # 1. Down
-            for down_block in self.down_blocks:
-
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = down_block(
+                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
+                )

             # 2. Mid
-            hidden_states = self.mid_block(
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
+            )

         # 3. Post-process
         hidden_states = self.norm_out(hidden_states)
         hidden_states = self.conv_act(hidden_states)
-
-
+
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache


 class CogVideoXDecoder3D(nn.Module):

@@ -846,9 +918,18 @@ class CogVideoXDecoder3D(nn.Module):

         self.gradient_checkpointing = False

-    def forward(
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXDecoder3D` class."""
-
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

         if self.training and self.gradient_checkpointing:

@@ -859,28 +940,45 @@ class CogVideoXDecoder3D(nn.Module):
                 return custom_forward

             # 1. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block),
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                sample,
+                conv_cache=conv_cache.get("mid_block"),
            )

             # 2. Up
-            for up_block in self.up_blocks:
-
-
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(up_block),
+                    hidden_states,
+                    temb,
+                    sample,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
         else:
             # 1. Mid
-            hidden_states = self.mid_block(
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
+            )

             # 2. Up
-            for up_block in self.up_blocks:
-
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = up_block(
+                    hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
+                )

         # 3. Post-process
-        hidden_states = self.norm_out(
+        hidden_states, new_conv_cache["norm_out"] = self.norm_out(
+            hidden_states, sample, conv_cache=conv_cache.get("norm_out")
+        )
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache


 class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):

@@ -999,6 +1097,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
         # number of temporal frames.
         self.num_latent_frames_batch_size = 2
+        self.num_sample_frames_batch_size = 8

         # We make the minimum height and width of sample for tiling half that of the generally supported
         self.tile_sample_min_height = sample_height // 2

@@ -1018,12 +1117,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
             module.gradient_checkpointing = value

-    def _clear_fake_context_parallel_cache(self):
-        for name, module in self.named_modules():
-            if isinstance(module, CogVideoXCausalConv3d):
-                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
-                module._clear_fake_context_parallel_cache()
-
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,

@@ -1081,6 +1174,32 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         """
         self.use_slicing = False

+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
+        frame_batch_size = self.num_sample_frames_batch_size
+        # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
+        # As the extra single frame is handled inside the loop, it is not required to round up here.
+        num_batches = max(num_frames // frame_batch_size, 1)
+        conv_cache = None
+        enc = []
+
+        for i in range(num_batches):
+            remaining_frames = num_frames % frame_batch_size
+            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+            end_frame = frame_batch_size * (i + 1) + remaining_frames
+            x_intermediate = x[:, :, start_frame:end_frame]
+            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
+            if self.quant_conv is not None:
+                x_intermediate = self.quant_conv(x_intermediate)
+            enc.append(x_intermediate)
+
+        enc = torch.cat(enc, dim=2)
+        return enc
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
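The new `_encode` above, and the updated `decode` and tiled paths below, split the temporal axis with the same index arithmetic: the first batch absorbs any extra leading frame, and the conv cache is threaded through the batches in order. A standalone sketch of just that arithmetic (the helper name is ours, not part of diffusers):

def frame_batches(num_frames: int, frame_batch_size: int) -> list:
    """Reproduce the start/end frame indices used by the batched encode/decode loops."""
    num_batches = max(num_frames // frame_batch_size, 1)
    remaining_frames = num_frames % frame_batch_size
    spans = []
    for i in range(num_batches):
        start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
        end_frame = frame_batch_size * (i + 1) + remaining_frames
        spans.append((start_frame, end_frame))
    return spans


# num_frames is expected to be 1, k * frame_batch_size, or k * frame_batch_size + 1.
print(frame_batches(49, 8))  # [(0, 9), (9, 17), (17, 25), (25, 33), (33, 41), (41, 49)]
print(frame_batches(48, 8))  # [(0, 8), (8, 16), ..., (40, 48)]
print(frame_batches(1, 8))   # [(0, 9)]; tensor slicing clamps this to the single frame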
@@ -1094,13 +1213,17 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

         Returns:
-                The latent representations of the encoded
+                The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-
-
-        h =
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)

@@ -1112,18 +1235,20 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             return self.tiled_decode(z, return_dict=return_dict)

         frame_batch_size = self.num_latent_frames_batch_size
+        num_batches = max(num_frames // frame_batch_size, 1)
+        conv_cache = None
         dec = []
-
+
+        for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
             end_frame = frame_batch_size * (i + 1) + remaining_frames
             z_intermediate = z[:, :, start_frame:end_frame]
             if self.post_quant_conv is not None:
                 z_intermediate = self.post_quant_conv(z_intermediate)
-            z_intermediate = self.decoder(z_intermediate)
+            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
             dec.append(z_intermediate)

-        self._clear_fake_context_parallel_cache()
         dec = torch.cat(dec, dim=2)

         if not return_dict:

@@ -1172,6 +1297,80 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         )
         return b

+    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        # For a rough memory estimate, take a look at the `tiled_decode` method.
+        batch_size, num_channels, num_frames, height, width = x.shape
+
+        overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
+        overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
+        blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
+        blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
+        row_limit_height = self.tile_latent_min_height - blend_extent_height
+        row_limit_width = self.tile_latent_min_width - blend_extent_width
+        frame_batch_size = self.num_sample_frames_batch_size
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, overlap_height):
+            row = []
+            for j in range(0, width, overlap_width):
+                # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
+                # As the extra single frame is handled inside the loop, it is not required to round up here.
+                num_batches = max(num_frames // frame_batch_size, 1)
+                conv_cache = None
+                time = []
+
+                for k in range(num_batches):
+                    remaining_frames = num_frames % frame_batch_size
+                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                    end_frame = frame_batch_size * (k + 1) + remaining_frames
+                    tile = x[
+                        :,
+                        :,
+                        start_frame:end_frame,
+                        i : i + self.tile_sample_min_height,
+                        j : j + self.tile_sample_min_width,
+                    ]
+                    tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
+                    if self.quant_conv is not None:
+                        tile = self.quant_conv(tile)
+                    time.append(tile)
+
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
+                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+            result_rows.append(torch.cat(result_row, dim=4))
+
+        enc = torch.cat(result_rows, dim=3)
+        return enc
+
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Decode a batch of images using a tiled decoder.
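The tiled paths stitch the per-tile latents back together with the class's pre-existing `blend_v`/`blend_h` helpers (the tail of one of them is visible at the top of the previous hunk). A sketch of the horizontal case, under the assumption that it is a simple linear cross-fade over the overlap region; see the class itself for the exact implementation.

import torch


def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Cross-fade the right edge of tile `a` into the left edge of tile `b` (sketch)."""
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        w = x / blend_extent  # 0 -> keep `a`, 1 -> keep `b`
        b[..., x] = a[..., a.shape[-1] - blend_extent + x] * (1 - w) + b[..., x] * w
    return b


# Two 1-pixel-high "tiles" (B, C, F, H, W) that overlap by 4 columns.
left = torch.ones(1, 1, 1, 1, 8)
right = torch.zeros(1, 1, 1, 1, 8)
print(blend_h(left, right, 4)[..., :4])  # tensor([[[[[1.0000, 0.7500, 0.5000, 0.2500]]]]])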
@@ -1212,8 +1411,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for i in range(0, height, overlap_height):
             row = []
             for j in range(0, width, overlap_width):
+                num_batches = max(num_frames // frame_batch_size, 1)
+                conv_cache = None
                 time = []
-
+
+                for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                     end_frame = frame_batch_size * (k + 1) + remaining_frames

@@ -1226,9 +1428,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                     ]
                     if self.post_quant_conv is not None:
                         tile = self.post_quant_conv(tile)
-                    tile = self.decoder(tile)
+                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                     time.append(tile)
-
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)

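Taken together, these changes give the CogVideoX VAE batch slicing, frame-batched encoding with an explicit conv cache, and tiled encoding alongside the existing tiled decoding. A hedged usage sketch against the 0.31.0 methods shown in this diff; the checkpoint id and the input shape are assumptions chosen for illustration.

import torch
from diffusers import AutoencoderKLCogVideoX

# Checkpoint id is an assumption; any CogVideoX VAE weights loaded into this class should behave the same.
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

vae.enable_slicing()  # encode/decode one batch element at a time (slicing-aware encode is new here)
vae.enable_tiling()   # spatial tiling, now used for encoding as well as decoding

# (batch, channels, frames, height, width); 9 frames = 8 * 1 + 1 matches the expected frame pattern.
video = torch.randn(1, 3, 9, 480, 720, dtype=torch.float16, device="cuda")

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
    recon = vae.decode(latents).sample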