diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. diffusers/__init__.py +38 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +238 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +40 -7
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +6 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
  53. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  54. diffusers/pipelines/cogview3/__init__.py +47 -0
  55. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  56. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  57. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  58. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  60. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  62. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  63. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  64. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  66. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  67. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  68. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  70. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  71. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  72. diffusers/pipelines/flux/__init__.py +10 -0
  73. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  74. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  76. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  77. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  78. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  79. diffusers/pipelines/free_noise_utils.py +365 -5
  80. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  81. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  82. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  83. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  84. diffusers/pipelines/kolors/tokenizer.py +4 -0
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  86. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  87. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  89. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  90. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  91. diffusers/pipelines/pag/__init__.py +6 -0
  92. diffusers/pipelines/pag/pag_utils.py +8 -2
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  96. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  97. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  98. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  100. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  101. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  102. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  103. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  106. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  107. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  108. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  109. diffusers/pipelines/pipeline_utils.py +123 -180
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  111. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  113. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  117. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  120. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  121. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  122. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  123. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  129. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  131. diffusers/quantizers/__init__.py +16 -0
  132. diffusers/quantizers/auto.py +126 -0
  133. diffusers/quantizers/base.py +233 -0
  134. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  135. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  136. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  137. diffusers/quantizers/quantization_config.py +391 -0
  138. diffusers/schedulers/scheduling_ddim.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  140. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm.py +4 -1
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  143. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  148. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  149. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  150. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  151. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  152. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  155. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  156. diffusers/schedulers/scheduling_sasolver.py +78 -1
  157. diffusers/schedulers/scheduling_unclip.py +4 -1
  158. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  159. diffusers/training_utils.py +48 -18
  160. diffusers/utils/__init__.py +2 -1
  161. diffusers/utils/dummy_pt_objects.py +60 -0
  162. diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
  163. diffusers/utils/hub_utils.py +16 -4
  164. diffusers/utils/import_utils.py +31 -8
  165. diffusers/utils/loading_utils.py +28 -4
  166. diffusers/utils/peft_utils.py +3 -3
  167. diffusers/utils/testing_utils.py +59 -0
  168. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  169. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
  170. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  172. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  173. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -41,7 +41,9 @@ class CogVideoXSafeConv3d(nn.Conv3d):
     """
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
+        memory_count = (
+            (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
+        )
 
         # Set to 2GB, suitable for CuDNN
        if memory_count > 2:
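The replacement computes the same size estimate without materializing an intermediate tensor: element count times 2 bytes (half precision), divided by 1024³ to get GiB, compared against the 2 GiB CuDNN threshold. A quick sketch of that arithmetic with a hypothetical input shape:

```python
# Sketch of the memory estimate above (assumes 2 bytes per element, i.e. float16/bfloat16).
# Hypothetical input of shape (B, C, T, H, W) = (1, 128, 17, 240, 360):
num_elements = 1 * 128 * 17 * 240 * 360
memory_gib = num_elements * 2 / 1024**3
print(f"{memory_gib:.2f} GiB")  # ~0.35 GiB, well under the 2 GiB threshold
```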
@@ -115,34 +117,24 @@ class CogVideoXCausalConv3d(nn.Module):
             dilation=dilation,
         )
 
-        self.conv_cache = None
-
-    def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def fake_context_parallel_forward(
+        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         kernel_size = self.time_kernel_size
         if kernel_size > 1:
-            cached_inputs = (
-                [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
-            )
+            cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
             inputs = torch.cat(cached_inputs + [inputs], dim=2)
         return inputs
 
-    def _clear_fake_context_parallel_cache(self):
-        del self.conv_cache
-        self.conv_cache = None
-
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        inputs = self.fake_context_parallel_forward(inputs)
-
-        self._clear_fake_context_parallel_cache()
-        # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
-        # hundred megabytes and so let's not do it for now
-        self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+    def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
+        inputs = self.fake_context_parallel_forward(inputs, conv_cache)
+        conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
 
         padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
         inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
 
         output = self.conv(inputs)
-        return output
+        return output, conv_cache
 
 
 class CogVideoXSpatialNorm3D(nn.Module):
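The hunk above makes `CogVideoXCausalConv3d` stateless: instead of the module stashing the trailing frames on `self.conv_cache`, the caller receives the cache and passes it back in with the next temporal chunk. A minimal PyTorch toy (not the diffusers implementation) sketching that caller-owned cache pattern:

```python
from typing import Optional

import torch
import torch.nn as nn


class ToyCausalConv3d(nn.Module):
    """Toy causal conv along time; the caller keeps the cache, mirroring the 0.31.0 API change."""

    def __init__(self, channels: int, time_kernel_size: int = 3):
        super().__init__()
        self.time_kernel_size = time_kernel_size
        self.conv = nn.Conv3d(channels, channels, (time_kernel_size, 1, 1))

    def forward(self, x: torch.Tensor, conv_cache: Optional[torch.Tensor] = None):
        # Prepend either the cached frames or replicated first frames (causal padding).
        if conv_cache is None:
            conv_cache = x[:, :, :1].repeat(1, 1, self.time_kernel_size - 1, 1, 1)
        x = torch.cat([conv_cache, x], dim=2)
        # The next chunk needs the last (kernel - 1) input frames of this chunk.
        new_cache = x[:, :, -(self.time_kernel_size - 1) :].clone()
        return self.conv(x), new_cache


conv = ToyCausalConv3d(channels=4)
video = torch.randn(1, 4, 8, 16, 16)
cache, outputs = None, []
for chunk in video.split(2, dim=2):             # process the video in temporal chunks
    out, cache = conv(chunk, conv_cache=cache)  # thread the cache between chunks
    outputs.append(out)
full = torch.cat(outputs, dim=2)
```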
@@ -172,7 +164,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
         self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
 
-    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
+    ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         if f.shape[2] > 1 and f.shape[2] % 2 == 1:
             f_first, f_rest = f[:, :, :1], f[:, :, 1:]
             f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
@@ -183,9 +180,12 @@ class CogVideoXSpatialNorm3D(nn.Module):
         else:
             zq = F.interpolate(zq, size=f.shape[-3:])
 
+        conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
+        conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
+
         norm_f = self.norm_layer(f)
-        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
-        return new_f
+        new_f = norm_f * conv_y + conv_b
+        return new_f, new_conv_cache
 
 
 class CogVideoXResnetBlock3D(nn.Module):
@@ -236,6 +236,7 @@ class CogVideoXResnetBlock3D(nn.Module):
         self.out_channels = out_channels
         self.nonlinearity = get_activation(non_linearity)
         self.use_conv_shortcut = conv_shortcut
+        self.spatial_norm_dim = spatial_norm_dim
 
         if spatial_norm_dim is None:
             self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
@@ -279,34 +280,43 @@ class CogVideoXResnetBlock3D(nn.Module):
         inputs: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         hidden_states = inputs
 
         if zq is not None:
-            hidden_states = self.norm1(hidden_states, zq)
+            hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
         else:
             hidden_states = self.norm1(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
-        hidden_states = self.conv1(hidden_states)
+        hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
 
         if temb is not None:
             hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
 
         if zq is not None:
-            hidden_states = self.norm2(hidden_states, zq)
+            hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
         else:
             hidden_states = self.norm2(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states)
+        hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
 
         if self.in_channels != self.out_channels:
-            inputs = self.conv_shortcut(inputs)
+            if self.use_conv_shortcut:
+                inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
+                    inputs, conv_cache=conv_cache.get("conv_shortcut")
+                )
+            else:
+                inputs = self.conv_shortcut(inputs)
 
         hidden_states = hidden_states + inputs
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXDownBlock3D(nn.Module):
@@ -392,8 +402,16 @@ class CogVideoXDownBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
+        r"""Forward method of the `CogVideoXDownBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -402,17 +420,23 @@ class CogVideoXDownBlock3D(nn.Module):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXMidBlock3D(nn.Module):
@@ -480,8 +504,16 @@ class CogVideoXMidBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
+        r"""Forward method of the `CogVideoXMidBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -490,13 +522,15 @@ class CogVideoXMidBlock3D(nn.Module):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXUpBlock3D(nn.Module):
@@ -584,9 +618,16 @@ class CogVideoXUpBlock3D(nn.Module):
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
         r"""Forward method of the `CogVideoXUpBlock3D` class."""
-        for resnet in self.resnets:
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -595,17 +636,23 @@ class CogVideoXUpBlock3D(nn.Module):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
                 hidden_states = upsampler(hidden_states)
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXEncoder3D(nn.Module):
@@ -705,9 +752,18 @@ class CogVideoXEncoder3D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXEncoder3D` class."""
-        hidden_states = self.conv_in(sample)
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
 
         if self.training and self.gradient_checkpointing:
 
@@ -718,28 +774,44 @@
                 return custom_forward
 
             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block), hidden_states, temb, None
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(down_block),
+                    hidden_states,
+                    temb,
+                    None,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
 
             # 2. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, None
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                None,
+                conv_cache=conv_cache.get("mid_block"),
             )
         else:
             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = down_block(hidden_states, temb, None)
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = down_block(
+                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
             # 2. Mid
-            hidden_states = self.mid_block(hidden_states, temb, None)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
+            )
 
         # 3. Post-process
         hidden_states = self.norm_out(hidden_states)
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-        return hidden_states
+
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXDecoder3D(nn.Module):
@@ -846,9 +918,18 @@ class CogVideoXDecoder3D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXDecoder3D` class."""
-        hidden_states = self.conv_in(sample)
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
 
         if self.training and self.gradient_checkpointing:
 
@@ -859,28 +940,45 @@
                 return custom_forward
 
             # 1. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, sample
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                sample,
+                conv_cache=conv_cache.get("mid_block"),
             )
 
             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block), hidden_states, temb, sample
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(up_block),
+                    hidden_states,
+                    temb,
+                    sample,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
         else:
             # 1. Mid
-            hidden_states = self.mid_block(hidden_states, temb, sample)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
+            )
 
             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = up_block(hidden_states, temb, sample)
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = up_block(
+                    hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         # 3. Post-process
-        hidden_states = self.norm_out(hidden_states, sample)
+        hidden_states, new_conv_cache["norm_out"] = self.norm_out(
+            hidden_states, sample, conv_cache=conv_cache.get("norm_out")
+        )
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-        return hidden_states
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache
 
 
 class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@@ -999,6 +1097,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
         # number of temporal frames.
         self.num_latent_frames_batch_size = 2
+        self.num_sample_frames_batch_size = 8
 
         # We make the minimum height and width of sample for tiling half that of the generally supported
         self.tile_sample_min_height = sample_height // 2
@@ -1018,12 +1117,6 @@
         if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
             module.gradient_checkpointing = value
 
-    def _clear_fake_context_parallel_cache(self):
-        for name, module in self.named_modules():
-            if isinstance(module, CogVideoXCausalConv3d):
-                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
-                module._clear_fake_context_parallel_cache()
-
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
@@ -1081,6 +1174,32 @@
         """
         self.use_slicing = False
 
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
+        frame_batch_size = self.num_sample_frames_batch_size
+        # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
+        # As the extra single frame is handled inside the loop, it is not required to round up here.
+        num_batches = max(num_frames // frame_batch_size, 1)
+        conv_cache = None
+        enc = []
+
+        for i in range(num_batches):
+            remaining_frames = num_frames % frame_batch_size
+            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+            end_frame = frame_batch_size * (i + 1) + remaining_frames
+            x_intermediate = x[:, :, start_frame:end_frame]
+            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
+            if self.quant_conv is not None:
+                x_intermediate = self.quant_conv(x_intermediate)
+            enc.append(x_intermediate)
+
+        enc = torch.cat(enc, dim=2)
+        return enc
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
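The note in `_encode` above relies on the index arithmetic below: an optional extra leading frame (`num_frames = frame_batch_size * k + 1`) is absorbed into the first batch, and every later batch is shifted by `remaining_frames`. A small sketch of the resulting ranges (values here are illustrative):

```python
# Frame-batching indices as used in `_encode`, assuming num_frames is
# 1, k * frame_batch_size, or k * frame_batch_size + 1.
frame_batch_size = 8
for num_frames in (1, 16, 17):
    num_batches = max(num_frames // frame_batch_size, 1)
    remaining_frames = num_frames % frame_batch_size
    ranges = []
    for i in range(num_batches):
        start = frame_batch_size * i + (0 if i == 0 else remaining_frames)
        end = frame_batch_size * (i + 1) + remaining_frames
        ranges.append((start, min(end, num_frames)))  # slicing clamps to num_frames anyway
    print(num_frames, ranges)
# 1  -> [(0, 1)]           a single frame is one (short) batch
# 16 -> [(0, 8), (8, 16)]  exact multiples split evenly
# 17 -> [(0, 9), (9, 17)]  the extra leading frame goes into the first batch
```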
@@ -1094,13 +1213,17 @@
                 Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
 
         Returns:
-                The latent representations of the encoded images. If `return_dict` is True, a
+                The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self.encoder(x)
-        if self.quant_conv is not None:
-            h = self.quant_conv(h)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
@@ -1112,18 +1235,20 @@
             return self.tiled_decode(z, return_dict=return_dict)
 
         frame_batch_size = self.num_latent_frames_batch_size
+        num_batches = max(num_frames // frame_batch_size, 1)
+        conv_cache = None
         dec = []
-        for i in range(num_frames // frame_batch_size):
+
+        for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
             end_frame = frame_batch_size * (i + 1) + remaining_frames
             z_intermediate = z[:, :, start_frame:end_frame]
             if self.post_quant_conv is not None:
                 z_intermediate = self.post_quant_conv(z_intermediate)
-            z_intermediate = self.decoder(z_intermediate)
+            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
             dec.append(z_intermediate)
 
-        self._clear_fake_context_parallel_cache()
         dec = torch.cat(dec, dim=2)
 
         if not return_dict:
@@ -1172,6 +1297,80 @@
             )
         return b
 
+    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        # For a rough memory estimate, take a look at the `tiled_decode` method.
+        batch_size, num_channels, num_frames, height, width = x.shape
+
+        overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
+        overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
+        blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
+        blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
+        row_limit_height = self.tile_latent_min_height - blend_extent_height
+        row_limit_width = self.tile_latent_min_width - blend_extent_width
+        frame_batch_size = self.num_sample_frames_batch_size
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, overlap_height):
+            row = []
+            for j in range(0, width, overlap_width):
+                # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
+                # As the extra single frame is handled inside the loop, it is not required to round up here.
+                num_batches = max(num_frames // frame_batch_size, 1)
+                conv_cache = None
+                time = []
+
+                for k in range(num_batches):
+                    remaining_frames = num_frames % frame_batch_size
+                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                    end_frame = frame_batch_size * (k + 1) + remaining_frames
+                    tile = x[
+                        :,
+                        :,
+                        start_frame:end_frame,
+                        i : i + self.tile_sample_min_height,
+                        j : j + self.tile_sample_min_width,
+                    ]
+                    tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
+                    if self.quant_conv is not None:
+                        tile = self.quant_conv(tile)
+                    time.append(tile)
+
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
+                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+            result_rows.append(torch.cat(result_row, dim=4))
+
+        enc = torch.cat(result_rows, dim=3)
+        return enc
+
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Decode a batch of images using a tiled decoder.
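For orientation, a hedged usage sketch of the new tiled and sliced encode path added above; the checkpoint id, dtype, and input shape are assumptions for illustration, not taken from this diff:

```python
import torch
from diffusers import AutoencoderKLCogVideoX

# Assumed checkpoint: the CogVideoX-2b repo with a "vae" subfolder.
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
vae.enable_tiling()   # route spatially large inputs through tiled_encode / tiled_decode
vae.enable_slicing()  # encode batch elements one at a time

# video: (batch, channels, frames, height, width); frames should be 1, 8*k, or 8*k + 1
video = torch.randn(1, 3, 17, 480, 720, dtype=torch.float16, device="cuda")
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
    decoded = vae.decode(latents).sample
```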
@@ -1212,8 +1411,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for i in range(0, height, overlap_height):
             row = []
             for j in range(0, width, overlap_width):
+                num_batches = max(num_frames // frame_batch_size, 1)
+                conv_cache = None
                 time = []
-                for k in range(num_frames // frame_batch_size):
+
+                for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                     end_frame = frame_batch_size * (k + 1) + remaining_frames
@@ -1226,9 +1428,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                     ]
                     if self.post_quant_conv is not None:
                         tile = self.post_quant_conv(tile)
-                    tile = self.decoder(tile)
+                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                     time.append(tile)
-                self._clear_fake_context_parallel_cache()
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
 