diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +34 -2
- diffusers/configuration_utils.py +12 -0
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +257 -54
- diffusers/loaders/__init__.py +2 -0
- diffusers/loaders/ip_adapter.py +5 -1
- diffusers/loaders/lora_base.py +14 -7
- diffusers/loaders/lora_conversion_utils.py +332 -0
- diffusers/loaders/lora_pipeline.py +707 -41
- diffusers/loaders/peft.py +1 -0
- diffusers/loaders/single_file_utils.py +81 -4
- diffusers/loaders/textual_inversion.py +2 -0
- diffusers/loaders/unet.py +39 -8
- diffusers/models/__init__.py +4 -0
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +86 -10
- diffusers/models/attention_processor.py +169 -133
- diffusers/models/autoencoders/autoencoder_kl.py +71 -11
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
- diffusers/models/controlnet_flux.py +536 -0
- diffusers/models/controlnet_sd3.py +7 -3
- diffusers/models/controlnet_sparsectrl.py +0 -1
- diffusers/models/embeddings.py +170 -61
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +182 -14
- diffusers/models/modeling_utils.py +283 -46
- diffusers/models/normalization.py +79 -0
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
- diffusers/models/transformers/pixart_transformer_2d.py +9 -1
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +161 -44
- diffusers/models/transformers/transformer_sd3.py +7 -1
- diffusers/models/unets/unet_2d_condition.py +8 -8
- diffusers/models/unets/unet_motion_model.py +41 -63
- diffusers/models/upsampling.py +6 -6
- diffusers/pipelines/__init__.py +35 -6
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
- diffusers/pipelines/auto_pipeline.py +39 -8
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +10 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -20
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
- diffusers/pipelines/pag/__init__.py +6 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_loading_utils.py +225 -27
- diffusers/pipelines/pipeline_utils.py +123 -180
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +126 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/quantization_config.py +391 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +4 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
- diffusers/schedulers/scheduling_deis_multistep.py +78 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_sasolver.py +78 -1
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
- diffusers/training_utils.py +48 -18
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
- diffusers/utils/hub_utils.py +16 -4
- diffusers/utils/import_utils.py +31 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +3 -3
- diffusers/utils/testing_utils.py +59 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
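The new diffusers/quantizers package in the listing above (auto.py, base.py, bitsandbytes/, quantization_config.py) is the quantization backend added in 0.31.0. As a minimal, hedged sketch of how such a bitsandbytes config is typically handed to a model loader in this release — the checkpoint id, dtype, and exact keyword are illustrative assumptions taken from the package layout, not from this diff, and `bitsandbytes` plus a CUDA GPU are required:

# Hedged sketch of the 0.31.0 quantization entry point; checkpoint id is an assumption.
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# Configure 4-bit weight quantization via the new bitsandbytes quantizer backend.
quant_config = BitsAndBytesConfig(load_in_4bit=True)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",   # assumed repo; any Flux-style checkpoint layout works
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)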
--- a/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/diffusers/models/autoencoders/autoencoder_kl.py
@@ -18,6 +18,7 @@ import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import deprecate
 from ...utils.accelerate_utils import apply_forward_hook
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -245,6 +246,18 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
         self.set_attn_processor(processor)
 
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
+            return self._tiled_encode(x)
+
+        enc = self.encoder(x)
+        if self.quant_conv is not None:
+            enc = self.quant_conv(enc)
+
+        return enc
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
@@ -261,21 +274,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 The latent representations of the encoded images. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            return self.tiled_encode(x, return_dict=return_dict)
-
         if self.use_slicing and x.shape[0] > 1:
-            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
             h = torch.cat(encoded_slices)
         else:
-            h = self.encoder(x)
-
-        if self.quant_conv is not None:
-            moments = self.quant_conv(h)
-        else:
-            moments = h
+            h = self._encode(x)
 
-        posterior = DiagonalGaussianDistribution(moments)
+        posterior = DiagonalGaussianDistribution(h)
 
         if not return_dict:
             return (posterior,)
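The two hunks above fold the old inline tiling and quant_conv handling out of `encode` and into a private `_encode` helper (tiling itself moves to the `_tiled_encode` method added further down), so `encode` now only handles slicing and wraps the raw moments in a `DiagonalGaussianDistribution`. The public entry point is unchanged; a minimal usage sketch, assuming an arbitrary VAE checkpoint:

# Hedged sketch of the unchanged public API: enable_tiling() now takes effect
# inside the private _encode() helper instead of inside encode() itself.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # assumed checkpoint
vae.enable_tiling()

image = torch.randn(1, 3, 1024, 1024)  # dummy NCHW batch in [-1, 1]
with torch.no_grad():
    posterior = vae.encode(image).latent_dist   # DiagonalGaussianDistribution
    latents = posterior.sample() * vae.config.scaling_factor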
@@ -337,6 +342,54 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b
 
+    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into 512x512 tiles and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = self.encoder(tile)
+                if self.config.use_quant_conv:
+                    tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        enc = torch.cat(result_rows, dim=2)
+        return enc
+
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
         r"""Encode a batch of images using a tiled encoder.
 
@@ -356,6 +409,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                 `tuple` is returned.
         """
+        deprecation_message = (
+            "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
+            "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
+            "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
+        )
+        deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
+
         overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
         row_limit = self.tile_latent_min_size - blend_extent
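For code that called `tiled_encode` directly, the deprecation above means the `return_dict` path keeps working until 1.0.0 but now warns; the forward-compatible pattern is the private `_tiled_encode` plus an explicit `DiagonalGaussianDistribution`, roughly as sketched below. This is a hedged sketch reusing `vae` and `image` from the previous example; callers that only use `vae.encode` are unaffected.

# Hedged migration sketch for direct tiled_encode() users.
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

# 0.30.x style: still works in 0.31.0 but emits a deprecation warning.
posterior_old = vae.tiled_encode(image, return_dict=True).latent_dist

# 0.31.0-forward style: _tiled_encode() returns the raw moments and the caller
# builds the distribution, mirroring what encode() now does internally.
moments = vae._tiled_encode(image)
posterior_new = DiagonalGaussianDistribution(moments)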
--- a/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -41,7 +41,9 @@ class CogVideoXSafeConv3d(nn.Conv3d):
     """
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
+        memory_count = (
+            (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
+        )
 
         # Set to 2GB, suitable for CuDNN
         if memory_count > 2:
@@ -115,34 +117,24 @@ class CogVideoXCausalConv3d(nn.Module):
             dilation=dilation,
         )
 
-        self.conv_cache = None
-
-    def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def fake_context_parallel_forward(
+        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         kernel_size = self.time_kernel_size
         if kernel_size > 1:
-            cached_inputs = (
-                [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
-            )
+            cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
             inputs = torch.cat(cached_inputs + [inputs], dim=2)
         return inputs
 
-    def _clear_fake_context_parallel_cache(self):
-        del self.conv_cache
-        self.conv_cache = None
-
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        inputs = self.fake_context_parallel_forward(inputs)
-
-        self._clear_fake_context_parallel_cache()
-        # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
-        # hundred megabytes and so let's not do it for now
-        self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+    def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
+        inputs = self.fake_context_parallel_forward(inputs, conv_cache)
+        conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
 
         padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
         inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
 
         output = self.conv(inputs)
-        return output
+        return output, conv_cache
 
 
 class CogVideoXSpatialNorm3D(nn.Module):
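The hunk above is the heart of the CogVideoX VAE change in 0.31.0: `CogVideoXCausalConv3d` no longer keeps `self.conv_cache` on the module, so nothing has to be cleared between calls; the caller passes in the previous chunk's trailing frames and receives the updated cache back. Below is a self-contained toy sketch of that stateless-cache pattern, deliberately reduced to a 1D convolution — the names and shapes are illustrative, not the diffusers class itself.

# Toy illustration of the stateless cache pattern adopted above: the layer never
# keeps state between calls; the caller threads `cache` through chunked inputs.
from typing import Optional, Tuple

import torch
import torch.nn as nn


class CachedCausalConv1d(nn.Module):
    def __init__(self, channels: int, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(channels, channels, kernel_size)

    def forward(
        self, x: torch.Tensor, cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if cache is None:
            # First chunk: left-pad by repeating the first frame, mirroring the
            # "fake context parallel" padding in the hunk above.
            cache = x[:, :, :1].repeat(1, 1, self.kernel_size - 1)
        x = torch.cat([cache, x], dim=-1)
        # Keep the trailing frames; the caller passes them into the next chunk.
        new_cache = x[:, :, -(self.kernel_size - 1):].clone()
        return self.conv(x), new_cache


layer = CachedCausalConv1d(channels=4, kernel_size=3).eval()
video = torch.randn(1, 4, 16)                 # one long sequence of 16 "frames"
chunks, cache = [], None
with torch.no_grad():
    for chunk in video.split(4, dim=-1):      # process 4 frames at a time
        out, cache = layer(chunk, cache)      # the caller threads the cache explicitly
        chunks.append(out)
    chunked = torch.cat(chunks, dim=-1)
    full, _ = layer(video)                    # single-shot result for comparison
assert torch.allclose(chunked, full, atol=1e-5)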
@@ -172,7 +164,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
         self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
 
-    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
+    ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         if f.shape[2] > 1 and f.shape[2] % 2 == 1:
             f_first, f_rest = f[:, :, :1], f[:, :, 1:]
             f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
@@ -183,9 +180,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         else:
             zq = F.interpolate(zq, size=f.shape[-3:])
 
+        conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
+        conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
+
         norm_f = self.norm_layer(f)
-        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
-        return new_f
+        new_f = norm_f * conv_y + conv_b
+        return new_f, new_conv_cache
 
 
 class CogVideoXResnetBlock3D(nn.Module):
@@ -236,6 +236,7 @@ class CogVideoXResnetBlock3D(nn.Module):
         self.out_channels = out_channels
         self.nonlinearity = get_activation(non_linearity)
         self.use_conv_shortcut = conv_shortcut
+        self.spatial_norm_dim = spatial_norm_dim
 
         if spatial_norm_dim is None:
             self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
@@ -279,34 +280,43 @@ class CogVideoXResnetBlock3D(nn.Module):
         inputs: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         hidden_states = inputs
 
         if zq is not None:
-            hidden_states = self.norm1(hidden_states, zq)
+            hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
         else:
            hidden_states = self.norm1(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
-        hidden_states = self.conv1(hidden_states)
+        hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
 
         if temb is not None:
             hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
 
         if zq is not None:
-            hidden_states = self.norm2(hidden_states, zq)
+            hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
         else:
             hidden_states = self.norm2(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states)
+        hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
 
         if self.in_channels != self.out_channels:
-            inputs = self.conv_shortcut(inputs)
+            if self.use_conv_shortcut:
+                inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
+                    inputs, conv_cache=conv_cache.get("conv_shortcut")
+                )
+            else:
+                inputs = self.conv_shortcut(inputs)
 
         hidden_states = hidden_states + inputs
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXDownBlock3D(nn.Module):
@@ -392,8 +402,16 @@ class CogVideoXDownBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
+        r"""Forward method of the `CogVideoXDownBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -402,17 +420,23 @@ class CogVideoXDownBlock3D(nn.Module):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXMidBlock3D(nn.Module):
@@ -480,8 +504,16 @@ class CogVideoXMidBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
+        r"""Forward method of the `CogVideoXMidBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -490,13 +522,15 @@ class CogVideoXMidBlock3D(nn.Module):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXUpBlock3D(nn.Module):
@@ -584,9 +618,16 @@ class CogVideoXUpBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
         r"""Forward method of the `CogVideoXUpBlock3D` class."""
-        for resnet in self.resnets:
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -595,17 +636,23 @@ class CogVideoXUpBlock3D(nn.Module):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
                 hidden_states = upsampler(hidden_states)
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXEncoder3D(nn.Module):
@@ -705,9 +752,18 @@ class CogVideoXEncoder3D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXEncoder3D` class."""
-        hidden_states = self.conv_in(sample)
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
 
         if self.training and self.gradient_checkpointing:
 
@@ -718,28 +774,44 @@ class CogVideoXEncoder3D(nn.Module):
                 return custom_forward
 
             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block), hidden_states, temb, None
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(down_block),
+                    hidden_states,
+                    temb,
+                    None,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
 
             # 2. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, None
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                None,
+                conv_cache=conv_cache.get("mid_block"),
             )
         else:
             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = down_block(hidden_states, temb, None)
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = down_block(
+                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
             # 2. Mid
-            hidden_states = self.mid_block(hidden_states, temb, None)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
+            )
 
         # 3. Post-process
         hidden_states = self.norm_out(hidden_states)
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-        return hidden_states
+
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXDecoder3D(nn.Module):
@@ -846,9 +918,18 @@ class CogVideoXDecoder3D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXDecoder3D` class."""
-        hidden_states = self.conv_in(sample)
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
 
         if self.training and self.gradient_checkpointing:
 
@@ -859,28 +940,45 @@ class CogVideoXDecoder3D(nn.Module):
                 return custom_forward
 
             # 1. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, sample
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                sample,
+                conv_cache=conv_cache.get("mid_block"),
            )
 
             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block), hidden_states, temb, sample
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(up_block),
+                    hidden_states,
+                    temb,
+                    sample,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
         else:
             # 1. Mid
-            hidden_states = self.mid_block(hidden_states, temb, sample)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
+            )
 
             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = up_block(hidden_states, temb, sample)
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = up_block(
+                    hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         # 3. Post-process
-        hidden_states = self.norm_out(hidden_states, sample)
+        hidden_states, new_conv_cache["norm_out"] = self.norm_out(
+            hidden_states, sample, conv_cache=conv_cache.get("norm_out")
+        )
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-        return hidden_states
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache
 
 
 class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@@ -1019,12 +1117,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
             module.gradient_checkpointing = value
 
-    def _clear_fake_context_parallel_cache(self):
-        for name, module in self.named_modules():
-            if isinstance(module, CogVideoXCausalConv3d):
-                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
-                module._clear_fake_context_parallel_cache()
-
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
@@ -1090,21 +1182,22 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
         frame_batch_size = self.num_sample_frames_batch_size
         # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-        num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+        # As the extra single frame is handled inside the loop, it is not required to round up here.
+        num_batches = max(num_frames // frame_batch_size, 1)
+        conv_cache = None
         enc = []
+
         for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
             end_frame = frame_batch_size * (i + 1) + remaining_frames
             x_intermediate = x[:, :, start_frame:end_frame]
-            x_intermediate = self.encoder(x_intermediate)
+            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
             if self.quant_conv is not None:
                 x_intermediate = self.quant_conv(x_intermediate)
             enc.append(x_intermediate)
 
-        self._clear_fake_context_parallel_cache()
         enc = torch.cat(enc, dim=2)
-
         return enc
 
     @apply_forward_hook
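With the module-level cache gone (the removed `_clear_fake_context_parallel_cache` a few hunks up), the frame-batched loops now thread the cache explicitly, as the hunk above shows: `conv_cache` starts as `None`, each `self.encoder(...)` call returns the updated cache, and it is fed into the next batch. Externally the CogVideoX VAE is used exactly as before; a hedged sketch, where the checkpoint id, dtype, and tensor sizes are assumptions for illustration and a large GPU is needed:

# Hedged sketch: the public AutoencoderKLCogVideoX API is unchanged; the per-batch
# conv_cache threading shown above happens inside encode()/decode().
import torch
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16  # assumed checkpoint
).to("cuda")
vae.enable_tiling()

video = torch.randn(1, 3, 49, 480, 720, dtype=torch.float16, device="cuda")  # B, C, F, H, W
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
    frames = vae.decode(latents).sample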
@@ -1142,8 +1235,10 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             return self.tiled_decode(z, return_dict=return_dict)
 
         frame_batch_size = self.num_latent_frames_batch_size
-        num_batches = num_frames // frame_batch_size
+        num_batches = max(num_frames // frame_batch_size, 1)
+        conv_cache = None
         dec = []
+
         for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
@@ -1151,10 +1246,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             z_intermediate = z[:, :, start_frame:end_frame]
             if self.post_quant_conv is not None:
                 z_intermediate = self.post_quant_conv(z_intermediate)
-            z_intermediate = self.decoder(z_intermediate)
+            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
             dec.append(z_intermediate)
 
-        self._clear_fake_context_parallel_cache()
         dec = torch.cat(dec, dim=2)
 
         if not return_dict:
@@ -1237,8 +1331,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             row = []
             for j in range(0, width, overlap_width):
                 # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-                num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+                # As the extra single frame is handled inside the loop, it is not required to round up here.
+                num_batches = max(num_frames // frame_batch_size, 1)
+                conv_cache = None
                 time = []
+
                 for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1250,11 +1347,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                         i : i + self.tile_sample_min_height,
                         j : j + self.tile_sample_min_width,
                     ]
-                    tile = self.encoder(tile)
+                    tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
                     if self.quant_conv is not None:
                         tile = self.quant_conv(tile)
                     time.append(tile)
-
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
 
@@ -1314,8 +1411,10 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for i in range(0, height, overlap_height):
             row = []
             for j in range(0, width, overlap_width):
-                num_batches = num_frames // frame_batch_size
+                num_batches = max(num_frames // frame_batch_size, 1)
+                conv_cache = None
                 time = []
+
                 for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1329,9 +1428,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                     ]
                     if self.post_quant_conv is not None:
                         tile = self.post_quant_conv(tile)
-                    tile = self.decoder(tile)
+                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                     time.append(tile)
-
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
 