diffusers-0.30.0-py3-none-any.whl → diffusers-0.30.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +1 -1
- diffusers/loaders/ip_adapter.py +2 -0
- diffusers/loaders/lora_pipeline.py +37 -7
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_utils.py +36 -11
- diffusers/models/attention_processor.py +142 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +327 -91
- diffusers/models/embeddings.py +84 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +17 -1
- diffusers/models/transformers/cogvideox_transformer_3d.py +196 -56
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +78 -19
- diffusers/pipelines/flux/pipeline_flux.py +1 -1
- diffusers/utils/export_utils.py +50 -3
- diffusers/utils/import_utils.py +19 -0
- diffusers/utils/loading_utils.py +16 -12
- {diffusers-0.30.0.dist-info → diffusers-0.30.2.dist-info}/METADATA +1 -1
- {diffusers-0.30.0.dist-info → diffusers-0.30.2.dist-info}/RECORD +21 -21
- {diffusers-0.30.0.dist-info → diffusers-0.30.2.dist-info}/WHEEL +1 -1
- {diffusers-0.30.0.dist-info → diffusers-0.30.2.dist-info}/LICENSE +0 -0
- {diffusers-0.30.0.dist-info → diffusers-0.30.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.0.dist-info → diffusers-0.30.2.dist-info}/top_level.txt +0 -0
diffusers/models/autoencoders/autoencoder_kl_cogvideox.py CHANGED

@@ -36,7 +36,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 class CogVideoXSafeConv3d(nn.Conv3d):
-    """
+    r"""
     A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
     """
 
@@ -68,12 +68,12 @@ class CogVideoXCausalConv3d(nn.Module):
     r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
 
     Args:
-        in_channels (int): Number of channels in the input tensor.
-        out_channels (int): Number of output channels.
-        kernel_size (
-        stride (int
-        dilation (int
-        pad_mode (str
+        in_channels (`int`): Number of channels in the input tensor.
+        out_channels (`int`): Number of output channels produced by the convolution.
+        kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
+        stride (`int`, defaults to `1`): Stride of the convolution.
+        dilation (`int`, defaults to `1`): Dilation rate of the convolution.
+        pad_mode (`str`, defaults to `"constant"`): Padding mode.
     """
 
     def __init__(
@@ -118,19 +118,12 @@ class CogVideoXCausalConv3d(nn.Module):
         self.conv_cache = None
 
     def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        dim = self.temporal_dim
         kernel_size = self.time_kernel_size
-        if kernel_size
-
-
-
-
-        if self.conv_cache is not None:
-            inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
-        else:
-            inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)
-
-        inputs = inputs.transpose(0, dim).contiguous()
+        if kernel_size > 1:
+            cached_inputs = (
+                [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
+            )
+            inputs = torch.cat(cached_inputs + [inputs], dim=2)
         return inputs
 
     def _clear_fake_context_parallel_cache(self):
@@ -138,16 +131,17 @@ class CogVideoXCausalConv3d(nn.Module):
         self.conv_cache = None
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-
+        inputs = self.fake_context_parallel_forward(inputs)
 
         self._clear_fake_context_parallel_cache()
-
+        # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
+        # hundred megabytes and so let's not do it for now
+        self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
 
         padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
-
+        inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
 
-
-        output = output_parallel
+        output = self.conv(inputs)
         return output
 
 
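The rewritten `fake_context_parallel_forward` and `forward` above implement temporal caching for the causal 3D convolution: on the first chunk the initial latent frame is repeated `kernel_size - 1` times along the temporal axis, and on later chunks the frames cached from the previous call are prepended instead. A minimal sketch of that padding behaviour with toy tensor values (shapes and numbers are illustrative only):

```python
import torch

kernel_size = 3                                 # temporal kernel size of the causal conv
x = torch.arange(4.0).view(1, 1, 4, 1, 1)       # toy (B, C, T, H, W) chunk with T=4

# First call: no cache yet, so the first frame is repeated (kernel_size - 1) times.
conv_cache = None
cached = [conv_cache] if conv_cache is not None else [x[:, :, :1]] * (kernel_size - 1)
padded = torch.cat(cached + [x], dim=2)
print(padded[0, 0, :, 0, 0])                    # tensor([0., 0., 0., 1., 2., 3.])

# The forward pass then keeps the trailing (kernel_size - 1) frames for the next chunk.
conv_cache = padded[:, :, -kernel_size + 1 :].clone()
print(conv_cache.shape)                         # torch.Size([1, 1, 2, 1, 1])
```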
@@ -163,6 +157,8 @@ class CogVideoXSpatialNorm3D(nn.Module):
             The number of channels for input to group normalization layer, and output of the spatial norm layer.
         zq_channels (`int`):
             The number of channels for the quantized vector as described in the paper.
+        groups (`int`):
+            Number of groups to separate the channels into for group normalization.
     """
 
     def __init__(
@@ -197,17 +193,26 @@ class CogVideoXResnetBlock3D(nn.Module):
     A 3D ResNet block used in the CogVideoX model.
 
     Args:
-        in_channels (int):
-
-
-
-
-
-
-
-
-
-
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`, *optional*):
+            Number of output channels. If None, defaults to `in_channels`.
+        dropout (`float`, defaults to `0.0`):
+            Dropout rate.
+        temb_channels (`int`, defaults to `512`):
+            Number of time embedding channels.
+        groups (`int`, defaults to `32`):
+            Number of groups to separate the channels into for group normalization.
+        eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        non_linearity (`str`, defaults to `"swish"`):
+            Activation function to use.
+        conv_shortcut (bool, defaults to `False`):
+            Whether or not to use a convolution shortcut.
+        spatial_norm_dim (`int`, *optional*):
+            The dimension to use for spatial norm if it is to be used instead of group norm.
+        pad_mode (str, defaults to `"first"`):
+            Padding mode.
     """
 
     def __init__(
@@ -309,18 +314,28 @@ class CogVideoXDownBlock3D(nn.Module):
     A downsampling block used in the CogVideoX model.
 
     Args:
-        in_channels (int):
-
-
-
-
-
-
-
-
-
-
-
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`, *optional*):
+            Number of output channels. If None, defaults to `in_channels`.
+        temb_channels (`int`, defaults to `512`):
+            Number of time embedding channels.
+        num_layers (`int`, defaults to `1`):
+            Number of resnet layers.
+        dropout (`float`, defaults to `0.0`):
+            Dropout rate.
+        resnet_eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        resnet_act_fn (`str`, defaults to `"swish"`):
+            Activation function to use.
+        resnet_groups (`int`, defaults to `32`):
+            Number of groups to separate the channels into for group normalization.
+        add_downsample (`bool`, defaults to `True`):
+            Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
+        compress_time (`bool`, defaults to `False`):
+            Whether or not to downsample across temporal dimension.
+        pad_mode (str, defaults to `"first"`):
+            Padding mode.
     """
 
     _supports_gradient_checkpointing = True
@@ -405,15 +420,24 @@ class CogVideoXMidBlock3D(nn.Module):
     A middle block used in the CogVideoX model.
 
    Args:
-        in_channels (int):
-
-
-
-
-
-
-
-
+        in_channels (`int`):
+            Number of input channels.
+        temb_channels (`int`, defaults to `512`):
+            Number of time embedding channels.
+        dropout (`float`, defaults to `0.0`):
+            Dropout rate.
+        num_layers (`int`, defaults to `1`):
+            Number of resnet layers.
+        resnet_eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        resnet_act_fn (`str`, defaults to `"swish"`):
+            Activation function to use.
+        resnet_groups (`int`, defaults to `32`):
+            Number of groups to separate the channels into for group normalization.
+        spatial_norm_dim (`int`, *optional*):
+            The dimension to use for spatial norm if it is to be used instead of group norm.
+        pad_mode (str, defaults to `"first"`):
+            Padding mode.
     """
 
     _supports_gradient_checkpointing = True
@@ -480,19 +504,30 @@ class CogVideoXUpBlock3D(nn.Module):
    An upsampling block used in the CogVideoX model.
 
     Args:
-        in_channels (int):
-
-
-
-
-
-
-
-
-
-
-
-
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`, *optional*):
+            Number of output channels. If None, defaults to `in_channels`.
+        temb_channels (`int`, defaults to `512`):
+            Number of time embedding channels.
+        dropout (`float`, defaults to `0.0`):
+            Dropout rate.
+        num_layers (`int`, defaults to `1`):
+            Number of resnet layers.
+        resnet_eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        resnet_act_fn (`str`, defaults to `"swish"`):
+            Activation function to use.
+        resnet_groups (`int`, defaults to `32`):
+            Number of groups to separate the channels into for group normalization.
+        spatial_norm_dim (`int`, defaults to `16`):
+            The dimension to use for spatial norm if it is to be used instead of group norm.
+        add_upsample (`bool`, defaults to `True`):
+            Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
+        compress_time (`bool`, defaults to `False`):
+            Whether or not to downsample across temporal dimension.
+        pad_mode (str, defaults to `"first"`):
+            Padding mode.
     """
 
     def __init__(
@@ -587,14 +622,12 @@ class CogVideoXEncoder3D(nn.Module):
             options.
         block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
             The number of output channels for each block.
+        act_fn (`str`, *optional*, defaults to `"silu"`):
+            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
         layers_per_block (`int`, *optional*, defaults to 2):
             The number of layers per block.
         norm_num_groups (`int`, *optional*, defaults to 32):
             The number of groups for normalization.
-        act_fn (`str`, *optional*, defaults to `"silu"`):
-            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
-        double_z (`bool`, *optional*, defaults to `True`):
-            Whether to double the number of output channels for the last block.
     """
 
     _supports_gradient_checkpointing = True
@@ -723,14 +756,12 @@ class CogVideoXDecoder3D(nn.Module):
             The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
         block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
             The number of output channels for each block.
+        act_fn (`str`, *optional*, defaults to `"silu"`):
+            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
         layers_per_block (`int`, *optional*, defaults to 2):
             The number of layers per block.
         norm_num_groups (`int`, *optional*, defaults to 32):
             The number of groups for normalization.
-        act_fn (`str`, *optional*, defaults to `"silu"`):
-            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
-        norm_type (`str`, *optional*, defaults to `"group"`):
-            The normalization type to use. Can be either `"group"` or `"spatial"`.
     """
 
     _supports_gradient_checkpointing = True
@@ -871,7 +902,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             Tuple of block output channels.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         sample_size (`int`, *optional*, defaults to `32`): Sample input size.
-        scaling_factor (`float`, *optional*, defaults to
+        scaling_factor (`float`, *optional*, defaults to `1.15258426`):
             The component-wise standard deviation of the trained latent space computed using the first batch of the
             training set. This is used to scale the latent space to have unit variance when training the diffusion
             model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
@@ -911,7 +942,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         norm_eps: float = 1e-6,
         norm_num_groups: int = 32,
         temporal_compression_ratio: float = 4,
-
+        sample_height: int = 480,
+        sample_width: int = 720,
         scaling_factor: float = 1.15258426,
         shift_factor: Optional[float] = None,
         latents_mean: Optional[Tuple[float]] = None,
@@ -950,25 +982,105 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.use_slicing = False
         self.use_tiling = False
 
-
-
-
-
-
+        # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
+        # recommended because the temporal parts of the VAE, here, are tricky to understand.
+        # If you decode X latent frames together, the number of output frames is:
+        # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
+        #
+        # Example with num_latent_frames_batch_size = 2:
+        # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
+        #   => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
+        #   => 6 * 8 = 48 frames
+        # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
+        #   => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
+        #      ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
+        #   => 1 * 9 + 5 * 8 = 49 frames
+        # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
+        # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
+        # number of temporal frames.
+        self.num_latent_frames_batch_size = 2
+
+        # We make the minimum height and width of sample for tiling half that of the generally supported
+        self.tile_sample_min_height = sample_height // 2
+        self.tile_sample_min_width = sample_width // 2
+        self.tile_latent_min_height = int(
+            self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
         )
-        self.
-
+        self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
+
+        # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
+        # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
+        # and so the tiling implementation has only been tested on those specific resolutions.
+        self.tile_overlap_factor_height = 1 / 6
+        self.tile_overlap_factor_width = 1 / 5
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
             module.gradient_checkpointing = value
 
-    def
+    def _clear_fake_context_parallel_cache(self):
         for name, module in self.named_modules():
             if isinstance(module, CogVideoXCausalConv3d):
                 logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
                 module._clear_fake_context_parallel_cache()
 
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_overlap_factor_height: Optional[float] = None,
+        tile_overlap_factor_width: Optional[float] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_overlap_factor_height (`int`, *optional*):
+                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+                no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
+                value might cause more tiles to be processed leading to slow down of the decoding process.
+            tile_overlap_factor_width (`int`, *optional*):
+                The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
+                are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
+                value might cause more tiles to be processed leading to slow down of the decoding process.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_latent_min_height = int(
+            self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
+        )
+        self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
+        self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
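The `enable_tiling`/`enable_slicing` switches added above are the public entry points for the memory savings in this release. A hedged usage sketch through the CogVideoX pipeline; the checkpoint id and generation settings are illustrative and not part of this diff:

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Illustrative checkpoint; any CogVideoX checkpoint using AutoencoderKLCogVideoX should behave the same.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")

# New in 0.30.2: tiled and sliced VAE decoding to reduce peak memory while decoding latents to frames.
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

frames = pipe(prompt="A panda playing a guitar by a lake", num_frames=49).frames[0]
export_to_video(frames, "cogvideox_sample.mp4", fps=8)
```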
@@ -993,8 +1105,34 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
 
+    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        batch_size, num_channels, num_frames, height, width = z.shape
+
+        if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=return_dict)
+
+        frame_batch_size = self.num_latent_frames_batch_size
+        dec = []
+        for i in range(num_frames // frame_batch_size):
+            remaining_frames = num_frames % frame_batch_size
+            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+            end_frame = frame_batch_size * (i + 1) + remaining_frames
+            z_intermediate = z[:, :, start_frame:end_frame]
+            if self.post_quant_conv is not None:
+                z_intermediate = self.post_quant_conv(z_intermediate)
+            z_intermediate = self.decoder(z_intermediate)
+            dec.append(z_intermediate)
+
+        self._clear_fake_context_parallel_cache()
+        dec = torch.cat(dec, dim=2)
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
     @apply_forward_hook
-    def decode(self, z: torch.
+    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         """
         Decode a batch of images.
 
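The `_decode` loop above batches latent frames in groups of `num_latent_frames_batch_size`, folding any remainder into the first slice (the "13 latent frames" special case described in the constructor comments). A small stand-alone walk-through of that index arithmetic, using 13 frames purely as an example:

```python
# Reproduces the start/end frame indices computed inside _decode for
# num_frames=13 and frame_batch_size=2: the first slice absorbs the remainder.
num_frames, frame_batch_size = 13, 2

slices = []
for i in range(num_frames // frame_batch_size):
    remaining_frames = num_frames % frame_batch_size
    start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
    end_frame = frame_batch_size * (i + 1) + remaining_frames
    slices.append((start_frame, end_frame))

print(slices)  # [(0, 3), (3, 5), (5, 7), (7, 9), (9, 11), (11, 13)]
```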
@@ -1007,13 +1145,111 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             [`~models.vae.DecoderOutput`] or `tuple`:
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
+        """
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
 
+        if not return_dict:
+            return (decoded,)
+        return DecoderOutput(sample=decoded)
+
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of images using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
         """
-
-
-
+        # Rough memory assessment:
+        # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
+        # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
+        # - Assume fp16 (2 bytes per value).
+        # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
+        #
+        # Memory assessment when using tiling:
+        # - Assume everything as above but now HxW is 240x360 by tiling in half
+        # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
+
+        batch_size, num_channels, num_frames, height, width = z.shape
+
+        overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
+        overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
+        blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
+        blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
+        row_limit_height = self.tile_sample_min_height - blend_extent_height
+        row_limit_width = self.tile_sample_min_width - blend_extent_width
+        frame_batch_size = self.num_latent_frames_batch_size
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, overlap_height):
+            row = []
+            for j in range(0, width, overlap_width):
+                time = []
+                for k in range(num_frames // frame_batch_size):
+                    remaining_frames = num_frames % frame_batch_size
+                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                    end_frame = frame_batch_size * (k + 1) + remaining_frames
+                    tile = z[
+                        :,
+                        :,
+                        start_frame:end_frame,
+                        i : i + self.tile_latent_min_height,
+                        j : j + self.tile_latent_min_width,
+                    ]
+                    if self.post_quant_conv is not None:
+                        tile = self.post_quant_conv(tile)
+                    tile = self.decoder(tile)
+                    time.append(tile)
+                self._clear_fake_context_parallel_cache()
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
+                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+            result_rows.append(torch.cat(result_row, dim=4))
+
+        dec = torch.cat(result_rows, dim=3)
+
         if not return_dict:
             return (dec,)
+
         return DecoderOutput(sample=dec)
 
     def forward(
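`tiled_decode` above stitches the decoded tiles back together with `blend_v`/`blend_h`, which linearly crossfade the overlap region so no seam is visible. A toy demonstration of the horizontal blend on dummy tensors (the values are illustrative only):

```python
import torch

blend_extent = 4
a = torch.ones(1, 1, 1, 1, 8)    # decoded left tile, shape (B, C, T, H, W)
b = torch.zeros(1, 1, 1, 1, 8)   # decoded right tile

# Same loop as blend_h: the left tile fades out while the right tile fades in.
for x in range(blend_extent):
    b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
        x / blend_extent
    )

print(b[0, 0, 0, 0])  # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000])
```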
diffusers/models/embeddings.py CHANGED

@@ -374,6 +374,90 @@ class CogVideoXPatchEmbed(nn.Module):
         return embeds
 
 
+def get_3d_rotary_pos_embed(
+    embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    """
+    RoPE for video tokens with 3D structure.
+
+    Args:
+        embed_dim: (`int`):
+            The embedding dimension size, corresponding to hidden_size_head.
+        crops_coords (`Tuple[int]`):
+            The top-left and bottom-right coordinates of the crop.
+        grid_size (`Tuple[int]`):
+            The grid size of the spatial positional embedding (height, width).
+        temporal_size (`int`):
+            The size of the temporal dimension.
+        theta (`float`):
+            Scaling factor for frequency computation.
+        use_real (`bool`):
+            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
+    """
+    start, stop = crops_coords
+    grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+    grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+    # Compute dimensions for each axis
+    dim_t = embed_dim // 4
+    dim_h = embed_dim // 8 * 3
+    dim_w = embed_dim // 8 * 3
+
+    # Temporal frequencies
+    freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
+    grid_t = torch.from_numpy(grid_t).float()
+    freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
+    freqs_t = freqs_t.repeat_interleave(2, dim=-1)
+
+    # Spatial frequencies for height and width
+    freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
+    freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
+    grid_h = torch.from_numpy(grid_h).float()
+    grid_w = torch.from_numpy(grid_w).float()
+    freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
+    freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
+    freqs_h = freqs_h.repeat_interleave(2, dim=-1)
+    freqs_w = freqs_w.repeat_interleave(2, dim=-1)
+
+    # Broadcast and concatenate tensors along specified dimension
+    def broadcast(tensors, dim=-1):
+        num_tensors = len(tensors)
+        shape_lens = {len(t.shape) for t in tensors}
+        assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+        shape_len = list(shape_lens)[0]
+        dim = (dim + shape_len) if dim < 0 else dim
+        dims = list(zip(*(list(t.shape) for t in tensors)))
+        expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+        assert all(
+            [*(len(set(t[1])) <= 2 for t in expandable_dims)]
+        ), "invalid dimensions for broadcastable concatenation"
+        max_dims = [(t[0], max(t[1])) for t in expandable_dims]
+        expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
+        expanded_dims.insert(dim, (dim, dims[dim]))
+        expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
+        tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
+        return torch.cat(tensors, dim=dim)
+
+    freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
+
+    t, h, w, d = freqs.shape
+    freqs = freqs.view(t * h * w, d)
+
+    # Generate sine and cosine components
+    sin = freqs.sin()
+    cos = freqs.cos()
+
+    if use_real:
+        return cos, sin
+    else:
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs_cis
+
+
 def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
     """
     RoPE for image tokens with 2d structure.