diffusers 0.30.0__py3-none-any.whl → 0.30.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -36,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name


  class CogVideoXSafeConv3d(nn.Conv3d):
- """
+ r"""
  A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
  """

@@ -68,12 +68,12 @@ class CogVideoXCausalConv3d(nn.Module):
  r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.

  Args:
- in_channels (int): Number of channels in the input tensor.
- out_channels (int): Number of output channels.
- kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel.
- stride (int, optional): Stride of the convolution. Default is 1.
- dilation (int, optional): Dilation rate of the convolution. Default is 1.
- pad_mode (str, optional): Padding mode. Default is "constant".
+ in_channels (`int`): Number of channels in the input tensor.
+ out_channels (`int`): Number of output channels produced by the convolution.
+ kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
+ stride (`int`, defaults to `1`): Stride of the convolution.
+ dilation (`int`, defaults to `1`): Dilation rate of the convolution.
+ pad_mode (`str`, defaults to `"constant"`): Padding mode.
  """

  def __init__(
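For orientation, a minimal usage sketch of this layer (the internal import path and the tensor sizes are assumptions for illustration, not part of the diff):

    import torch
    from diffusers.models.autoencoders.autoencoder_kl_cogvideox import CogVideoXCausalConv3d

    # Causal temporal padding keeps the frame count; only the spatial dims are padded symmetrically.
    conv = CogVideoXCausalConv3d(in_channels=4, out_channels=8, kernel_size=3)
    video = torch.randn(1, 4, 5, 32, 32)  # [batch, channels, frames, height, width]
    print(conv(video).shape)  # torch.Size([1, 8, 5, 32, 32])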
@@ -118,19 +118,12 @@ class CogVideoXCausalConv3d(nn.Module):
  self.conv_cache = None

  def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
- dim = self.temporal_dim
  kernel_size = self.time_kernel_size
- if kernel_size == 1:
- return inputs
-
- inputs = inputs.transpose(0, dim)
-
- if self.conv_cache is not None:
- inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
- else:
- inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)
-
- inputs = inputs.transpose(0, dim).contiguous()
+ if kernel_size > 1:
+ cached_inputs = (
+ [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
+ )
+ inputs = torch.cat(cached_inputs + [inputs], dim=2)
  return inputs

  def _clear_fake_context_parallel_cache(self):
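The rewritten `fake_context_parallel_forward` now concatenates along the frame axis (dim=2) directly instead of transposing. A standalone sketch of that concatenation, with illustrative tensor sizes:

    import torch

    # With kernel_size = 3 and no conv cache yet, the first frame is repeated
    # kernel_size - 1 times and prepended along the frame axis (dim=2).
    kernel_size = 3
    inputs = torch.randn(1, 8, 4, 16, 16)  # [batch, channels, frames, height, width]
    conv_cache = None

    if kernel_size > 1:
        cached_inputs = (
            [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
        )
        inputs = torch.cat(cached_inputs + [inputs], dim=2)

    print(inputs.shape)  # torch.Size([1, 8, 6, 16, 16]): 4 + (3 - 1) frames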
@@ -138,16 +131,17 @@ class CogVideoXCausalConv3d(nn.Module):
  self.conv_cache = None

  def forward(self, inputs: torch.Tensor) -> torch.Tensor:
- input_parallel = self.fake_context_parallel_forward(inputs)
+ inputs = self.fake_context_parallel_forward(inputs)

  self._clear_fake_context_parallel_cache()
- self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
+ # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
+ # hundred megabytes and so let's not do it for now
+ self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()

  padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
- input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
+ inputs = F.pad(inputs, padding_2d, mode="constant", value=0)

- output_parallel = self.conv(input_parallel)
- output = output_parallel
+ output = self.conv(inputs)
  return output


@@ -163,6 +157,8 @@ class CogVideoXSpatialNorm3D(nn.Module):
  The number of channels for input to group normalization layer, and output of the spatial norm layer.
  zq_channels (`int`):
  The number of channels for the quantized vector as described in the paper.
+ groups (`int`):
+ Number of groups to separate the channels into for group normalization.
  """

  def __init__(
@@ -197,17 +193,26 @@ class CogVideoXResnetBlock3D(nn.Module):
  A 3D ResNet block used in the CogVideoX model.

  Args:
- in_channels (int): Number of input channels.
- out_channels (Optional[int], optional):
- Number of output channels. If None, defaults to `in_channels`. Default is None.
- dropout (float, optional): Dropout rate. Default is 0.0.
- temb_channels (int, optional): Number of time embedding channels. Default is 512.
- groups (int, optional): Number of groups for group normalization. Default is 32.
- eps (float, optional): Epsilon value for normalization layers. Default is 1e-6.
- non_linearity (str, optional): Activation function to use. Default is "swish".
- conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False.
- spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
- pad_mode (str, optional): Padding mode. Default is "first".
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ non_linearity (`str`, defaults to `"swish"`):
+ Activation function to use.
+ conv_shortcut (bool, defaults to `False`):
+ Whether or not to use a convolution shortcut.
+ spatial_norm_dim (`int`, *optional*):
+ The dimension to use for spatial norm if it is to be used instead of group norm.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
  """

  def __init__(
@@ -309,18 +314,28 @@ class CogVideoXDownBlock3D(nn.Module):
  A downsampling block used in the CogVideoX model.

  Args:
- in_channels (int): Number of input channels.
- out_channels (int): Number of output channels.
- temb_channels (int): Number of time embedding channels.
- dropout (float, optional): Dropout rate. Default is 0.0.
- num_layers (int, optional): Number of layers in the block. Default is 1.
- resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
- resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
- resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
- add_downsample (bool, optional): If True, add a downsampling layer at the end of the block. Default is True.
- downsample_padding (int, optional): Padding for the downsampling layer. Default is 0.
- compress_time (bool, optional): If True, apply temporal compression. Default is False.
- pad_mode (str, optional): Padding mode. Default is "first".
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ resnet_groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ add_downsample (`bool`, defaults to `True`):
+ Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
+ compress_time (`bool`, defaults to `False`):
+ Whether or not to downsample across temporal dimension.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
  """

  _supports_gradient_checkpointing = True
@@ -405,15 +420,24 @@ class CogVideoXMidBlock3D(nn.Module):
  A middle block used in the CogVideoX model.

  Args:
- in_channels (int): Number of input channels.
- temb_channels (int): Number of time embedding channels.
- dropout (float, optional): Dropout rate. Default is 0.0.
- num_layers (int, optional): Number of layers in the block. Default is 1.
- resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
- resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
- resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
- spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
- pad_mode (str, optional): Padding mode. Default is "first".
+ in_channels (`int`):
+ Number of input channels.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ resnet_groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ spatial_norm_dim (`int`, *optional*):
+ The dimension to use for spatial norm if it is to be used instead of group norm.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
  """

  _supports_gradient_checkpointing = True
@@ -480,19 +504,30 @@ class CogVideoXUpBlock3D(nn.Module):
  An upsampling block used in the CogVideoX model.

  Args:
- in_channels (int): Number of input channels.
- out_channels (int): Number of output channels.
- temb_channels (int): Number of time embedding channels.
- dropout (float, optional): Dropout rate. Default is 0.0.
- num_layers (int, optional): Number of layers in the block. Default is 1.
- resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
- resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
- resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
- spatial_norm_dim (int, optional): Dimension of the spatial normalization. Default is 16.
- add_upsample (bool, optional): If True, add an upsampling layer at the end of the block. Default is True.
- upsample_padding (int, optional): Padding for the upsampling layer. Default is 1.
- compress_time (bool, optional): If True, apply temporal compression. Default is False.
- pad_mode (str, optional): Padding mode. Default is "first".
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ resnet_groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ spatial_norm_dim (`int`, defaults to `16`):
+ The dimension to use for spatial norm if it is to be used instead of group norm.
+ add_upsample (`bool`, defaults to `True`):
+ Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
+ compress_time (`bool`, defaults to `False`):
+ Whether or not to downsample across temporal dimension.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
  """

  def __init__(
@@ -587,14 +622,12 @@ class CogVideoXEncoder3D(nn.Module):
  options.
  block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
  The number of output channels for each block.
+ act_fn (`str`, *optional*, defaults to `"silu"`):
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
  layers_per_block (`int`, *optional*, defaults to 2):
  The number of layers per block.
  norm_num_groups (`int`, *optional*, defaults to 32):
  The number of groups for normalization.
- act_fn (`str`, *optional*, defaults to `"silu"`):
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
- double_z (`bool`, *optional*, defaults to `True`):
- Whether to double the number of output channels for the last block.
  """

  _supports_gradient_checkpointing = True
@@ -723,14 +756,12 @@ class CogVideoXDecoder3D(nn.Module):
  The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
  block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
  The number of output channels for each block.
+ act_fn (`str`, *optional*, defaults to `"silu"`):
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
  layers_per_block (`int`, *optional*, defaults to 2):
  The number of layers per block.
  norm_num_groups (`int`, *optional*, defaults to 32):
  The number of groups for normalization.
- act_fn (`str`, *optional*, defaults to `"silu"`):
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
- norm_type (`str`, *optional*, defaults to `"group"`):
- The normalization type to use. Can be either `"group"` or `"spatial"`.
  """

  _supports_gradient_checkpointing = True
@@ -871,7 +902,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
  Tuple of block output channels.
  act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
  sample_size (`int`, *optional*, defaults to `32`): Sample input size.
- scaling_factor (`float`, *optional*, defaults to 0.18215):
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
  The component-wise standard deviation of the trained latent space computed using the first batch of the
  training set. This is used to scale the latent space to have unit variance when training the diffusion
  model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
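The default `scaling_factor` moves from the SD-style 0.18215 to 1.15258426 for CogVideoX. A small illustrative sketch of how such a factor is applied and undone around the diffusion model (the latent shape is a hypothetical CogVideoX-style shape, not taken from the diff):

    import torch

    scaling_factor = 1.15258426
    latents = torch.randn(1, 16, 13, 60, 90)   # hypothetical [B, C, F, H/8, W/8] latents
    scaled = latents * scaling_factor          # z = z * scaling_factor before the diffusion model
    recovered = scaled / scaling_factor        # unscale again before decoding
    assert torch.allclose(latents, recovered)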
@@ -911,7 +942,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
  norm_eps: float = 1e-6,
  norm_num_groups: int = 32,
  temporal_compression_ratio: float = 4,
- sample_size: int = 256,
+ sample_height: int = 480,
+ sample_width: int = 720,
  scaling_factor: float = 1.15258426,
  shift_factor: Optional[float] = None,
  latents_mean: Optional[Tuple[float]] = None,
@@ -950,25 +982,105 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
  self.use_slicing = False
  self.use_tiling = False

- self.tile_sample_min_size = self.config.sample_size
- sample_size = (
- self.config.sample_size[0]
- if isinstance(self.config.sample_size, (list, tuple))
- else self.config.sample_size
+ # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
+ # recommended because the temporal parts of the VAE, here, are tricky to understand.
+ # If you decode X latent frames together, the number of output frames is:
+ # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
+ #
+ # Example with num_latent_frames_batch_size = 2:
+ # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
+ # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
+ # => 6 * 8 = 48 frames
+ # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
+ # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
+ # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
+ # => 1 * 9 + 5 * 8 = 49 frames
+ # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
+ # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
+ # number of temporal frames.
+ self.num_latent_frames_batch_size = 2
+
+ # We make the minimum height and width of sample for tiling half that of the generally supported
+ self.tile_sample_min_height = sample_height // 2
+ self.tile_sample_min_width = sample_width // 2
+ self.tile_latent_min_height = int(
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
  )
- self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
- self.tile_overlap_factor = 0.25
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
+
+ # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
+ # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
+ # and so the tiling implementation has only been tested on those specific resolutions.
+ self.tile_overlap_factor_height = 1 / 6
+ self.tile_overlap_factor_width = 1 / 5

  def _set_gradient_checkpointing(self, module, value=False):
  if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
  module.gradient_checkpointing = value

- def clear_fake_context_parallel_cache(self):
+ def _clear_fake_context_parallel_cache(self):
  for name, module in self.named_modules():
  if isinstance(module, CogVideoXCausalConv3d):
  logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
  module._clear_fake_context_parallel_cache()

+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_overlap_factor_height: Optional[float] = None,
+ tile_overlap_factor_width: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_overlap_factor_height (`int`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
+ value might cause more tiles to be processed leading to slow down of the decoding process.
+ tile_overlap_factor_width (`int`, *optional*):
+ The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
+ are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
+ value might cause more tiles to be processed leading to slow down of the decoding process.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_latent_min_height = int(
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
+ )
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
+ self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
+ self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
  @apply_forward_hook
  def encode(
  self, x: torch.Tensor, return_dict: bool = True
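A hedged usage sketch of the memory-saving switches added above (the model id and subfolder are assumptions about the published CogVideoX checkpoint layout; adjust to the checkpoint you actually use):

    import torch
    from diffusers import AutoencoderKLCogVideoX

    vae = AutoencoderKLCogVideoX.from_pretrained(
        "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
    )
    vae.enable_slicing()  # decode one batch element at a time
    vae.enable_tiling()   # decode in overlapping spatial tiles (defaults derived from sample_height/sample_width)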
@@ -993,8 +1105,34 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
  return (posterior,)
  return AutoencoderKLOutput(latent_dist=posterior)

+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ batch_size, num_channels, num_frames, height, width = z.shape
+
+ if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ frame_batch_size = self.num_latent_frames_batch_size
+ dec = []
+ for i in range(num_frames // frame_batch_size):
+ remaining_frames = num_frames % frame_batch_size
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
+ z_intermediate = z[:, :, start_frame:end_frame]
+ if self.post_quant_conv is not None:
+ z_intermediate = self.post_quant_conv(z_intermediate)
+ z_intermediate = self.decoder(z_intermediate)
+ dec.append(z_intermediate)
+
+ self._clear_fake_context_parallel_cache()
+ dec = torch.cat(dec, dim=2)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
  @apply_forward_hook
- def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
  """
  Decode a batch of images.

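For a concrete feel of the frame batching in the new `_decode`, here is a standalone sketch of the slice indices for a hypothetical 13-latent-frame input with `num_latent_frames_batch_size = 2` (the remainder is folded into the first slice, matching the "special case" described in the constructor comment):

    num_frames = 13
    frame_batch_size = 2
    slices = []
    for i in range(num_frames // frame_batch_size):
        remaining_frames = num_frames % frame_batch_size
        start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
        end_frame = frame_batch_size * (i + 1) + remaining_frames
        slices.append((start_frame, end_frame))
    print(slices)  # [(0, 3), (3, 5), (5, 7), (7, 9), (9, 11), (11, 13)]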
@@ -1007,13 +1145,111 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
  [`~models.vae.DecoderOutput`] or `tuple`:
  If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
  returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample

+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
  """
- if self.post_quant_conv is not None:
- z = self.post_quant_conv(z)
- dec = self.decoder(z)
+ # Rough memory assessment:
+ # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
+ # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
+ # - Assume fp16 (2 bytes per value).
+ # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
+ #
+ # Memory assessment when using tiling:
+ # - Assume everything as above but now HxW is 240x360 by tiling in half
+ # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
+
+ batch_size, num_channels, num_frames, height, width = z.shape
+
+ overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
+ overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
+ blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
+ blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
+ row_limit_height = self.tile_sample_min_height - blend_extent_height
+ row_limit_width = self.tile_sample_min_width - blend_extent_width
+ frame_batch_size = self.num_latent_frames_batch_size
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, overlap_height):
+ row = []
+ for j in range(0, width, overlap_width):
+ time = []
+ for k in range(num_frames // frame_batch_size):
+ remaining_frames = num_frames % frame_batch_size
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
+ tile = z[
+ :,
+ :,
+ start_frame:end_frame,
+ i : i + self.tile_latent_min_height,
+ j : j + self.tile_latent_min_width,
+ ]
+ if self.post_quant_conv is not None:
+ tile = self.post_quant_conv(tile)
+ tile = self.decoder(tile)
+ time.append(tile)
+ self._clear_fake_context_parallel_cache()
+ row.append(torch.cat(time, dim=2))
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ dec = torch.cat(result_rows, dim=3)
+
  if not return_dict:
  return (dec,)
+
  return DecoderOutput(sample=dec)

  def forward(
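To see what the tile sizes above work out to, here is the overlap arithmetic evaluated for the constructor defaults (a 480x720 sample and a spatial downscale factor of 8 from four `block_out_channels`; these concrete numbers are an assumption mirroring the stock CogVideoX configuration):

    tile_sample_min_height, tile_sample_min_width = 480 // 2, 720 // 2   # 240, 360
    tile_latent_min_height, tile_latent_min_width = 240 // 8, 360 // 8   # 30, 45
    tile_overlap_factor_height, tile_overlap_factor_width = 1 / 6, 1 / 5

    overlap_height = int(tile_latent_min_height * (1 - tile_overlap_factor_height))  # 25 latent rows per step
    overlap_width = int(tile_latent_min_width * (1 - tile_overlap_factor_width))     # 36 latent cols per step
    blend_extent_height = int(tile_sample_min_height * tile_overlap_factor_height)   # 40 sample rows blended
    blend_extent_width = int(tile_sample_min_width * tile_overlap_factor_width)      # 72 sample cols blended
    row_limit_height = tile_sample_min_height - blend_extent_height                  # 200
    row_limit_width = tile_sample_min_width - blend_extent_width                     # 288
    print(overlap_height, overlap_width, row_limit_height, row_limit_width)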
@@ -374,6 +374,90 @@ class CogVideoXPatchEmbed(nn.Module):
  return embeds


+ def get_3d_rotary_pos_embed(
+ embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ RoPE for video tokens with 3D structure.
+
+ Args:
+ embed_dim: (`int`):
+ The embedding dimension size, corresponding to hidden_size_head.
+ crops_coords (`Tuple[int]`):
+ The top-left and bottom-right coordinates of the crop.
+ grid_size (`Tuple[int]`):
+ The grid size of the spatial positional embedding (height, width).
+ temporal_size (`int`):
+ The size of the temporal dimension.
+ theta (`float`):
+ Scaling factor for frequency computation.
+ use_real (`bool`):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+ Returns:
+ `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
+ """
+ start, stop = crops_coords
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+ # Compute dimensions for each axis
+ dim_t = embed_dim // 4
+ dim_h = embed_dim // 8 * 3
+ dim_w = embed_dim // 8 * 3
+
+ # Temporal frequencies
+ freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
+ grid_t = torch.from_numpy(grid_t).float()
+ freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
+ freqs_t = freqs_t.repeat_interleave(2, dim=-1)
+
+ # Spatial frequencies for height and width
+ freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
+ freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
+ grid_h = torch.from_numpy(grid_h).float()
+ grid_w = torch.from_numpy(grid_w).float()
+ freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
+ freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
+ freqs_h = freqs_h.repeat_interleave(2, dim=-1)
+ freqs_w = freqs_w.repeat_interleave(2, dim=-1)
+
+ # Broadcast and concatenate tensors along specified dimension
+ def broadcast(tensors, dim=-1):
+ num_tensors = len(tensors)
+ shape_lens = {len(t.shape) for t in tensors}
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+ shape_len = list(shape_lens)[0]
+ dim = (dim + shape_len) if dim < 0 else dim
+ dims = list(zip(*(list(t.shape) for t in tensors)))
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+ assert all(
+ [*(len(set(t[1])) <= 2 for t in expandable_dims)]
+ ), "invalid dimensions for broadcastable concatenation"
+ max_dims = [(t[0], max(t[1])) for t in expandable_dims]
+ expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
+ expanded_dims.insert(dim, (dim, dims[dim]))
+ expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
+ tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
+ return torch.cat(tensors, dim=dim)
+
+ freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
+
+ t, h, w, d = freqs.shape
+ freqs = freqs.view(t * h * w, d)
+
+ # Generate sine and cosine components
+ sin = freqs.sin()
+ cos = freqs.cos()
+
+ if use_real:
+ return cos, sin
+ else:
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+ return freqs_cis
+
+
  def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
  """
  RoPE for image tokens with 2d structure.
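A hedged usage sketch of the newly added `get_3d_rotary_pos_embed` (the grid size, temporal size, and head dimension are illustrative values, roughly matching a 480x720, 13-latent-frame CogVideoX setting):

    from diffusers.models.embeddings import get_3d_rotary_pos_embed

    grid_size = (30, 45)                # (height, width) of the spatial token grid
    crops_coords = ((0, 0), grid_size)  # no cropping: span the full grid
    cos, sin = get_3d_rotary_pos_embed(
        embed_dim=64,
        crops_coords=crops_coords,
        grid_size=grid_size,
        temporal_size=13,
        use_real=True,
    )
    print(cos.shape, sin.shape)  # both torch.Size([17550, 64]), i.e. 13 * 30 * 45 positions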