diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. diffusers/__init__.py +34 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +170 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +35 -6
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +2 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
  53. diffusers/pipelines/cogview3/__init__.py +47 -0
  54. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  55. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  56. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  57. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  58. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  60. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  62. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  63. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  64. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  66. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  67. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  68. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  70. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  71. diffusers/pipelines/flux/__init__.py +10 -0
  72. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  73. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  74. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  76. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  77. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  78. diffusers/pipelines/free_noise_utils.py +365 -5
  79. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  80. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  81. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  82. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  83. diffusers/pipelines/kolors/tokenizer.py +4 -0
  84. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  86. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  87. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  89. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  90. diffusers/pipelines/pag/__init__.py +6 -0
  91. diffusers/pipelines/pag/pag_utils.py +8 -2
  92. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  96. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  97. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  98. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  100. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  101. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  102. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  103. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  106. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  107. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  108. diffusers/pipelines/pipeline_utils.py +123 -180
  109. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  111. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  117. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  120. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  121. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  122. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  126. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  127. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  130. diffusers/quantizers/__init__.py +16 -0
  131. diffusers/quantizers/auto.py +126 -0
  132. diffusers/quantizers/base.py +233 -0
  133. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  134. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  135. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  136. diffusers/quantizers/quantization_config.py +391 -0
  137. diffusers/schedulers/scheduling_ddim.py +4 -1
  138. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  140. diffusers/schedulers/scheduling_ddpm.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  142. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  143. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  145. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  146. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  147. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  149. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  150. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  151. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  152. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  153. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  154. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  155. diffusers/schedulers/scheduling_sasolver.py +78 -1
  156. diffusers/schedulers/scheduling_unclip.py +4 -1
  157. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  158. diffusers/training_utils.py +48 -18
  159. diffusers/utils/__init__.py +2 -1
  160. diffusers/utils/dummy_pt_objects.py +60 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
  162. diffusers/utils/hub_utils.py +16 -4
  163. diffusers/utils/import_utils.py +31 -8
  164. diffusers/utils/loading_utils.py +28 -4
  165. diffusers/utils/peft_utils.py +3 -3
  166. diffusers/utils/testing_utils.py +59 -0
  167. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  168. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
  169. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  170. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
  171. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  172. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
diffusers/models/autoencoders/autoencoder_kl.py +71 -11

@@ -18,6 +18,7 @@ import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import deprecate
 from ...utils.accelerate_utils import apply_forward_hook
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -245,6 +246,18 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
         self.set_attn_processor(processor)
 
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
+            return self._tiled_encode(x)
+
+        enc = self.encoder(x)
+        if self.quant_conv is not None:
+            enc = self.quant_conv(enc)
+
+        return enc
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
@@ -261,21 +274,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 The latent representations of the encoded images. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            return self.tiled_encode(x, return_dict=return_dict)
-
         if self.use_slicing and x.shape[0] > 1:
-            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
             h = torch.cat(encoded_slices)
         else:
-            h = self.encoder(x)
-
-        if self.quant_conv is not None:
-            moments = self.quant_conv(h)
-        else:
-            moments = h
+            h = self._encode(x)
 
-        posterior = DiagonalGaussianDistribution(moments)
+        posterior = DiagonalGaussianDistribution(h)
 
         if not return_dict:
             return (posterior,)
@@ -337,6 +342,54 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b
 
+    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into 512x512 tiles and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = self.encoder(tile)
+                if self.config.use_quant_conv:
+                    tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        enc = torch.cat(result_rows, dim=2)
+        return enc
+
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
         r"""Encode a batch of images using a tiled encoder.
 
@@ -356,6 +409,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                 `tuple` is returned.
         """
+        deprecation_message = (
+            "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
+            "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
+            "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
+        )
+        deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
+
         overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
         row_limit = self.tile_latent_min_size - blend_extent
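For downstream code, the public `AutoencoderKL.encode()` entry point is unchanged: tiling and slicing are now routed through the new private `_encode()` helper, and only direct calls to `tiled_encode(..., return_dict=...)` emit the deprecation warning added above. A minimal usage sketch (the checkpoint name and input size are illustrative assumptions, not taken from this diff):

# Minimal sketch, assuming the `stabilityai/sd-vae-ft-mse` checkpoint and a
# 1024x1024 input; the public encode() path is unchanged in 0.31.0.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.enable_tiling()   # large inputs are tiled inside the private _encode()
vae.enable_slicing()  # batches are split and each slice goes through _encode()

image = torch.randn(2, 3, 1024, 1024)
with torch.no_grad():
    posterior = vae.encode(image).latent_dist  # DiagonalGaussianDistribution
    latents = posterior.sample() * vae.config.scaling_factor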
diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -41,7 +41,9 @@ class CogVideoXSafeConv3d(nn.Conv3d):
     """
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
+        memory_count = (
+            (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
+        )
 
         # Set to 2GB, suitable for CuDNN
         if memory_count > 2:
@@ -115,34 +117,24 @@ class CogVideoXCausalConv3d(nn.Module):
             dilation=dilation,
         )
 
-        self.conv_cache = None
-
-    def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def fake_context_parallel_forward(
+        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         kernel_size = self.time_kernel_size
         if kernel_size > 1:
-            cached_inputs = (
-                [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
-            )
+            cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
             inputs = torch.cat(cached_inputs + [inputs], dim=2)
         return inputs
 
-    def _clear_fake_context_parallel_cache(self):
-        del self.conv_cache
-        self.conv_cache = None
-
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        inputs = self.fake_context_parallel_forward(inputs)
-
-        self._clear_fake_context_parallel_cache()
-        # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
-        # hundred megabytes and so let's not do it for now
-        self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+    def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
+        inputs = self.fake_context_parallel_forward(inputs, conv_cache)
+        conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
 
         padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
         inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
 
         output = self.conv(inputs)
-        return output
+        return output, conv_cache
 
 
 class CogVideoXSpatialNorm3D(nn.Module):
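The refactor above replaces the module-level `self.conv_cache` state with a cache value that is passed in and returned, so callers can thread it across frame batches without stale state leaking between unrelated calls. A minimal sketch of the same pattern on a toy 1D causal convolution (purely illustrative; `ToyCausalConv1d` is not part of diffusers):

import torch
import torch.nn as nn
from typing import Optional, Tuple


class ToyCausalConv1d(nn.Module):
    """Toy illustration of the stateless conv-cache pattern used above."""

    def __init__(self, channels: int, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(channels, channels, kernel_size)

    def forward(
        self, x: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        new_cache = None
        if self.kernel_size > 1:
            # Prepend either the cached tail of the previous chunk or replicated
            # first samples, mirroring fake_context_parallel_forward above.
            pad = conv_cache if conv_cache is not None else x[:, :, :1].repeat(1, 1, self.kernel_size - 1)
            x = torch.cat([pad, x], dim=2)
            # Remember the tail for the next chunk instead of storing it on `self`.
            new_cache = x[:, :, -(self.kernel_size - 1) :].clone()
        return self.conv(x), new_cache


layer = ToyCausalConv1d(channels=4, kernel_size=3)
chunks = torch.randn(1, 4, 12).split(4, dim=2)  # one long sequence, processed in chunks
conv_cache = None
outputs = []
for chunk in chunks:
    out, conv_cache = layer(chunk, conv_cache=conv_cache)  # thread the cache through
    outputs.append(out)
result = torch.cat(outputs, dim=2)  # same length as processing the sequence in one pass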
@@ -172,7 +164,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
         self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
 
-    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
+    ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         if f.shape[2] > 1 and f.shape[2] % 2 == 1:
             f_first, f_rest = f[:, :, :1], f[:, :, 1:]
             f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
@@ -183,9 +180,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         else:
             zq = F.interpolate(zq, size=f.shape[-3:])
 
+        conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
+        conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
+
         norm_f = self.norm_layer(f)
-        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
-        return new_f
+        new_f = norm_f * conv_y + conv_b
+        return new_f, new_conv_cache
 
 
 class CogVideoXResnetBlock3D(nn.Module):
@@ -236,6 +236,7 @@ class CogVideoXResnetBlock3D(nn.Module):
         self.out_channels = out_channels
         self.nonlinearity = get_activation(non_linearity)
        self.use_conv_shortcut = conv_shortcut
+        self.spatial_norm_dim = spatial_norm_dim
 
         if spatial_norm_dim is None:
             self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
@@ -279,34 +280,43 @@ class CogVideoXResnetBlock3D(nn.Module):
         inputs: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         hidden_states = inputs
 
         if zq is not None:
-            hidden_states = self.norm1(hidden_states, zq)
+            hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
         else:
             hidden_states = self.norm1(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
-        hidden_states = self.conv1(hidden_states)
+        hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
 
         if temb is not None:
             hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
 
         if zq is not None:
-            hidden_states = self.norm2(hidden_states, zq)
+            hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
         else:
             hidden_states = self.norm2(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states)
+        hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
 
         if self.in_channels != self.out_channels:
-            inputs = self.conv_shortcut(inputs)
+            if self.use_conv_shortcut:
+                inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
+                    inputs, conv_cache=conv_cache.get("conv_shortcut")
+                )
+            else:
+                inputs = self.conv_shortcut(inputs)
 
         hidden_states = hidden_states + inputs
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXDownBlock3D(nn.Module):
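Blocks that contain several causal convolutions (the spatial norm and resnet blocks above, and the down/mid/up blocks and encoder/decoder below) collect the per-submodule caches into a keyed dict: each layer reads its own slot from `conv_cache` and writes the updated value into `new_conv_cache`. Continuing the toy sketch from earlier (again purely illustrative, reusing the hypothetical `ToyCausalConv1d`):

from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn


class ToyCausalBlock(nn.Module):
    """Toy two-layer block showing the keyed conv_cache dict used by the 3D blocks above."""

    def __init__(self, channels: int, kernel_size: int = 3):
        super().__init__()
        self.conv1 = ToyCausalConv1d(channels, kernel_size)  # from the previous sketch
        self.conv2 = ToyCausalConv1d(channels, kernel_size)

    def forward(
        self, x: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        new_conv_cache = {}
        conv_cache = conv_cache or {}
        # Each sub-layer reads its own slot and writes its updated cache back.
        x, new_conv_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1"))
        x, new_conv_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2"))
        return x, new_conv_cache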
@@ -392,8 +402,16 @@ class CogVideoXDownBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
+        r"""Forward method of the `CogVideoXDownBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -402,17 +420,23 @@ class CogVideoXDownBlock3D(nn.Module):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXMidBlock3D(nn.Module):
@@ -480,8 +504,16 @@ class CogVideoXMidBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
+        r"""Forward method of the `CogVideoXMidBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -490,13 +522,15 @@ class CogVideoXMidBlock3D(nn.Module):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXUpBlock3D(nn.Module):
@@ -584,9 +618,16 @@ class CogVideoXUpBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
         r"""Forward method of the `CogVideoXUpBlock3D` class."""
-        for resnet in self.resnets:
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -595,17 +636,23 @@ class CogVideoXUpBlock3D(nn.Module):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
                 hidden_states = upsampler(hidden_states)
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXEncoder3D(nn.Module):
@@ -705,9 +752,18 @@ class CogVideoXEncoder3D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXEncoder3D` class."""
-        hidden_states = self.conv_in(sample)
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
 
         if self.training and self.gradient_checkpointing:
 
@@ -718,28 +774,44 @@ class CogVideoXEncoder3D(nn.Module):
                 return custom_forward
 
             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block), hidden_states, temb, None
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(down_block),
+                    hidden_states,
+                    temb,
+                    None,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
 
             # 2. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, None
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                None,
+                conv_cache=conv_cache.get("mid_block"),
            )
         else:
             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = down_block(hidden_states, temb, None)
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = down_block(
+                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
             # 2. Mid
-            hidden_states = self.mid_block(hidden_states, temb, None)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
+            )
 
         # 3. Post-process
         hidden_states = self.norm_out(hidden_states)
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-        return hidden_states
+
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXDecoder3D(nn.Module):
@@ -846,9 +918,18 @@ class CogVideoXDecoder3D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXDecoder3D` class."""
-        hidden_states = self.conv_in(sample)
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
 
         if self.training and self.gradient_checkpointing:
 
@@ -859,28 +940,45 @@ class CogVideoXDecoder3D(nn.Module):
                 return custom_forward
 
             # 1. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, sample
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                sample,
+                conv_cache=conv_cache.get("mid_block"),
             )
 
             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block), hidden_states, temb, sample
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(up_block),
+                    hidden_states,
+                    temb,
+                    sample,
+                    conv_cache=conv_cache.get(conv_cache_key),
                )
         else:
             # 1. Mid
-            hidden_states = self.mid_block(hidden_states, temb, sample)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
+            )
 
             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = up_block(hidden_states, temb, sample)
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = up_block(
+                    hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         # 3. Post-process
-        hidden_states = self.norm_out(hidden_states, sample)
+        hidden_states, new_conv_cache["norm_out"] = self.norm_out(
+            hidden_states, sample, conv_cache=conv_cache.get("norm_out")
+        )
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-        return hidden_states
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache
 
 
 class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@@ -1019,12 +1117,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
             module.gradient_checkpointing = value
 
-    def _clear_fake_context_parallel_cache(self):
-        for name, module in self.named_modules():
-            if isinstance(module, CogVideoXCausalConv3d):
-                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
-                module._clear_fake_context_parallel_cache()
-
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
@@ -1090,21 +1182,22 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
         frame_batch_size = self.num_sample_frames_batch_size
         # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-        num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+        # As the extra single frame is handled inside the loop, it is not required to round up here.
+        num_batches = max(num_frames // frame_batch_size, 1)
+        conv_cache = None
         enc = []
+
         for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
             end_frame = frame_batch_size * (i + 1) + remaining_frames
             x_intermediate = x[:, :, start_frame:end_frame]
-            x_intermediate = self.encoder(x_intermediate)
+            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
             if self.quant_conv is not None:
                 x_intermediate = self.quant_conv(x_intermediate)
             enc.append(x_intermediate)
 
-        self._clear_fake_context_parallel_cache()
         enc = torch.cat(enc, dim=2)
-
         return enc
 
     @apply_forward_hook
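The `num_batches = max(num_frames // frame_batch_size, 1)` change keeps the batching arithmetic identical for the `frame_batch_size * k` and `frame_batch_size * k + 1` cases while avoiding a zero batch count. A quick worked example (the values are illustrative, chosen to mimic the "k batches plus one extra frame" layout the comment describes):

frame_batch_size, num_frames = 2, 13  # 13 == 2 * 6 + 1
num_batches = max(num_frames // frame_batch_size, 1)  # 6; no rounding up needed
for i in range(num_batches):
    remaining_frames = num_frames % frame_batch_size  # the single extra frame
    start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
    end_frame = frame_batch_size * (i + 1) + remaining_frames
    print((start_frame, end_frame))
# (0, 3) (3, 5) (5, 7) (7, 9) (9, 11) (11, 13): the first batch absorbs the extra
# frame and every frame index from 0 to 12 is covered exactly once.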
@@ -1142,8 +1235,10 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             return self.tiled_decode(z, return_dict=return_dict)
 
         frame_batch_size = self.num_latent_frames_batch_size
-        num_batches = num_frames // frame_batch_size
+        num_batches = max(num_frames // frame_batch_size, 1)
+        conv_cache = None
         dec = []
+
         for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
@@ -1151,10 +1246,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             z_intermediate = z[:, :, start_frame:end_frame]
             if self.post_quant_conv is not None:
                 z_intermediate = self.post_quant_conv(z_intermediate)
-            z_intermediate = self.decoder(z_intermediate)
+            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
             dec.append(z_intermediate)
 
-        self._clear_fake_context_parallel_cache()
         dec = torch.cat(dec, dim=2)
 
         if not return_dict:
@@ -1237,8 +1331,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             row = []
             for j in range(0, width, overlap_width):
                 # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
-                num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+                # As the extra single frame is handled inside the loop, it is not required to round up here.
+                num_batches = max(num_frames // frame_batch_size, 1)
+                conv_cache = None
                 time = []
+
                 for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1250,11 +1347,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                         i : i + self.tile_sample_min_height,
                         j : j + self.tile_sample_min_width,
                     ]
-                    tile = self.encoder(tile)
+                    tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
                     if self.quant_conv is not None:
                         tile = self.quant_conv(tile)
                     time.append(tile)
-                self._clear_fake_context_parallel_cache()
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
 
@@ -1314,8 +1411,10 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for i in range(0, height, overlap_height):
             row = []
             for j in range(0, width, overlap_width):
-                num_batches = num_frames // frame_batch_size
+                num_batches = max(num_frames // frame_batch_size, 1)
+                conv_cache = None
                 time = []
+
                 for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1329,9 +1428,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                     ]
                     if self.post_quant_conv is not None:
                         tile = self.post_quant_conv(tile)
-                    tile = self.decoder(tile)
+                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                     time.append(tile)
-                self._clear_fake_context_parallel_cache()
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
 
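At the level of `AutoencoderKLCogVideoX`, none of this changes the public API: `encode`, `decode`, and their tiled variants now thread `conv_cache` through the encoder/decoder internally instead of clearing module state afterwards. A minimal usage sketch (checkpoint, device, and tensor sizes are illustrative assumptions):

import torch
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
vae.enable_tiling()   # tiled paths carry conv_cache across frame batches per tile
vae.enable_slicing()

# [batch, channels, frames, height, width]; frame count of the form 4k + 1
video = torch.randn(1, 3, 9, 256, 256, dtype=torch.float16, device="cuda")
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
    frames = vae.decode(latents).sample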