diffusers 0.31.0-py3-none-any.whl → 0.32.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +74 -28
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,23 @@
+ from dataclasses import dataclass
+ from typing import List, Union
+
+ import numpy as np
+ import PIL
+ import torch
+
+ from diffusers.utils import BaseOutput
+
+
+ @dataclass
+ class AllegroPipelineOutput(BaseOutput):
+     r"""
+     Output class for Allegro pipelines.
+
+     Args:
+         frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+             List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+             denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+             `(batch_size, num_frames, channels, height, width)`.
+     """
+
+     frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
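The hunk above is the new diffusers/pipelines/allegro/pipeline_output.py. A minimal sketch of the shape contract it documents, constructing the output class directly with a placeholder array (the values are illustrative, not taken from this diff):

import numpy as np
from diffusers.pipelines.allegro.pipeline_output import AllegroPipelineOutput

# Placeholder following the documented (batch_size, num_frames, channels, height, width) layout
frames = np.zeros((1, 8, 3, 64, 64), dtype=np.uint8)
out = AllegroPipelineOutput(frames=frames)

print(out.frames.shape)     # (1, 8, 3, 64, 64)
print(out["frames"].shape)  # BaseOutput also allows dict-style access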
@@ -21,14 +21,20 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 
  from ...image_processor import PipelineImageInput
  from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
- from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
+ from ...models import (
+     AutoencoderKL,
+     ControlNetModel,
+     ImageProjection,
+     MultiControlNetModel,
+     UNet2DConditionModel,
+     UNetMotionModel,
+ )
  from ...models.lora import adjust_lora_scale_text_encoder
  from ...models.unets.unet_motion_model import MotionAdapter
  from ...schedulers import KarrasDiffusionSchedulers
  from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
  from ...utils.torch_utils import is_compiled_module, randn_tensor
  from ...video_processor import VideoProcessor
- from ..controlnet.multicontrolnet import MultiControlNetModel
  from ..free_init_utils import FreeInitMixin
  from ..free_noise_utils import AnimateDiffFreeNoiseMixin
  from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
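MultiControlNetModel is now imported from ...models here, matching the new diffusers/models/controlnets/multicontrolnet.py module in the file list above. A hedged sketch of the import paths for downstream code; the old pipeline-level module shrinks to +7 -178 in this release, so it is assumed to remain only as a thin compatibility shim:

# New home in 0.32.0, mirroring the `from ...models import MultiControlNetModel` change above
from diffusers.models import MultiControlNetModel

# Old location; assumed to still resolve for backwards compatibility, possibly with a
# deprecation warning. Verify against the release notes before relying on it.
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel  # noqa: F811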
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
  from ...image_processor import PipelineImageInput, VaeImageProcessor
  from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
  from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
- from ...models.controlnet_sparsectrl import SparseControlNetModel
+ from ...models.controlnets.controlnet_sparsectrl import SparseControlNetModel
  from ...models.lora import adjust_lora_scale_text_encoder
  from ...models.unets.unet_motion_model import MotionAdapter
  from ...schedulers import KarrasDiffusionSchedulers
@@ -662,12 +662,6 @@ class AnimateDiffVideoToVideoPipeline(
              self.vae.to(dtype=torch.float32)
 
          if isinstance(generator, list):
-             if len(generator) != batch_size:
-                 raise ValueError(
-                     f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                     f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-                 )
-
              init_latents = [
                  self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
                  for i in range(batch_size)
@@ -21,7 +21,14 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 
  from ...image_processor import PipelineImageInput
  from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
- from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
+ from ...models import (
+     AutoencoderKL,
+     ControlNetModel,
+     ImageProjection,
+     MultiControlNetModel,
+     UNet2DConditionModel,
+     UNetMotionModel,
+ )
  from ...models.lora import adjust_lora_scale_text_encoder
  from ...models.unets.unet_motion_model import MotionAdapter
  from ...schedulers import (
@@ -35,7 +42,6 @@ from ...schedulers import (
  from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
  from ...utils.torch_utils import is_compiled_module, randn_tensor
  from ...video_processor import VideoProcessor
- from ..controlnet.multicontrolnet import MultiControlNetModel
  from ..free_init_utils import FreeInitMixin
  from ..free_noise_utils import AnimateDiffFreeNoiseMixin
  from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
@@ -788,12 +794,6 @@ class AnimateDiffVideoToVideoControlNetPipeline(
              self.vae.to(dtype=torch.float32)
 
          if isinstance(generator, list):
-             if len(generator) != batch_size:
-                 raise ValueError(
-                     f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                     f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-                 )
-
              init_latents = [
                  self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
                  for i in range(batch_size)
@@ -1112,7 +1112,7 @@ class CrossAttnDownBlock2D(nn.Module):
              )
 
          for i in range(num_layers):
-             if self.training and self.gradient_checkpointing:
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                  def create_custom_forward(module, return_dict=None):
                      def custom_forward(*inputs):
@@ -1290,7 +1290,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
              )
 
          for i in range(len(self.resnets[1:])):
-             if self.training and self.gradient_checkpointing:
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                  def create_custom_forward(module, return_dict=None):
                      def custom_forward(*inputs):
@@ -1464,7 +1464,7 @@ class CrossAttnUpBlock2D(nn.Module):
              res_hidden_states_tuple = res_hidden_states_tuple[:-1]
              hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-             if self.training and self.gradient_checkpointing:
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                  def create_custom_forward(module, return_dict=None):
                      def custom_forward(*inputs):
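These three hunks change the checkpointing guard from `self.training` to `torch.is_grad_enabled()`: activation checkpointing now also engages when a block is in eval mode but gradients are being computed (for example a frozen UNet inside a training loop), and is skipped under `torch.no_grad()` even if the module is in train mode. A small standalone sketch of the two predicates, plain PyTorch and independent of diffusers:

import torch
import torch.nn as nn

block = nn.Linear(4, 4)
block.eval()  # e.g. a frozen module used inside a training step

with torch.enable_grad():
    # old guard: block.training is False, so checkpointing would be skipped here
    # new guard: torch.is_grad_enabled() is True, so checkpointing still engages
    print(block.training, torch.is_grad_enabled())  # False True

block.train()
with torch.no_grad():
    # pure inference: the new guard correctly disables checkpointing
    print(block.training, torch.is_grad_enabled())  # True False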
@@ -387,7 +387,6 @@ class AuraFlowPipeline(DiffusionPipeline):
          prompt: Union[str, List[str]] = None,
          negative_prompt: Union[str, List[str]] = None,
          num_inference_steps: int = 50,
-         timesteps: List[int] = None,
          sigmas: List[float] = None,
          guidance_scale: float = 3.5,
          num_images_per_prompt: Optional[int] = 1,
@@ -424,10 +423,6 @@ class AuraFlowPipeline(DiffusionPipeline):
              sigmas (`List[float]`, *optional*):
                  Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
                  `num_inference_steps` and `timesteps` must be `None`.
-             timesteps (`List[int]`, *optional*):
-                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                 passed will be used. Must be in descending order.
              guidance_scale (`float`, *optional*, defaults to 5.0):
                  Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                  `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -522,9 +517,7 @@ class AuraFlowPipeline(DiffusionPipeline):
          # 4. Prepare timesteps
 
          # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-         timesteps, num_inference_steps = retrieve_timesteps(
-             self.scheduler, num_inference_steps, device, timesteps, sigmas
-         )
+         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
 
          # 5. Prepare latents.
          latent_channels = self.transformer.config.in_channels
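With the `timesteps` argument removed from AuraFlowPipeline.__call__, custom schedules now go through `sigmas` only. A hedged sketch of a call under the new signature; the checkpoint id is illustrative and the sigma schedule simply mirrors the commented-out np.linspace line above:

import numpy as np
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

num_inference_steps = 28
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps).tolist()

image = pipe(
    prompt="a watercolor fox",
    num_inference_steps=num_inference_steps,
    sigmas=sigmas,  # passing `timesteps=` here now raises a TypeError
).images[0]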
@@ -18,6 +18,7 @@ from collections import OrderedDict
  from huggingface_hub.utils import validate_hf_hub_args
 
  from ..configuration_utils import ConfigMixin
+ from ..models.controlnets import ControlNetUnionModel
  from ..utils import is_sentencepiece_available
  from .aura_flow import AuraFlowPipeline
  from .cogview3 import CogView3PlusPipeline
@@ -28,12 +29,18 @@ from .controlnet import (
      StableDiffusionXLControlNetImg2ImgPipeline,
      StableDiffusionXLControlNetInpaintPipeline,
      StableDiffusionXLControlNetPipeline,
+     StableDiffusionXLControlNetUnionImg2ImgPipeline,
+     StableDiffusionXLControlNetUnionInpaintPipeline,
+     StableDiffusionXLControlNetUnionPipeline,
  )
  from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
  from .flux import (
+     FluxControlImg2ImgPipeline,
+     FluxControlInpaintPipeline,
      FluxControlNetImg2ImgPipeline,
      FluxControlNetInpaintPipeline,
      FluxControlNetPipeline,
+     FluxControlPipeline,
      FluxImg2ImgPipeline,
      FluxInpaintPipeline,
      FluxPipeline,
@@ -61,10 +68,12 @@ from .lumina import LuminaText2ImgPipeline
  from .pag import (
      HunyuanDiTPAGPipeline,
      PixArtSigmaPAGPipeline,
+     StableDiffusion3PAGImg2ImgPipeline,
      StableDiffusion3PAGPipeline,
      StableDiffusionControlNetPAGInpaintPipeline,
      StableDiffusionControlNetPAGPipeline,
      StableDiffusionPAGImg2ImgPipeline,
+     StableDiffusionPAGInpaintPipeline,
      StableDiffusionPAGPipeline,
      StableDiffusionXLControlNetPAGImg2ImgPipeline,
      StableDiffusionXLControlNetPAGPipeline,
@@ -106,6 +115,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
          ("kandinsky3", Kandinsky3Pipeline),
          ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
          ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
+         ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline),
          ("wuerstchen", WuerstchenCombinedPipeline),
          ("cascade", StableCascadeCombinedPipeline),
          ("lcm", LatentConsistencyModelPipeline),
@@ -118,6 +128,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
          ("pixart-sigma-pag", PixArtSigmaPAGPipeline),
          ("auraflow", AuraFlowPipeline),
          ("flux", FluxPipeline),
+         ("flux-control", FluxControlPipeline),
          ("flux-controlnet", FluxControlNetPipeline),
          ("lumina", LuminaText2ImgPipeline),
          ("cogview3", CogView3PlusPipeline),
@@ -129,6 +140,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
          ("stable-diffusion", StableDiffusionImg2ImgPipeline),
          ("stable-diffusion-xl", StableDiffusionXLImg2ImgPipeline),
          ("stable-diffusion-3", StableDiffusion3Img2ImgPipeline),
+         ("stable-diffusion-3-pag", StableDiffusion3PAGImg2ImgPipeline),
          ("if", IFImg2ImgPipeline),
          ("kandinsky", KandinskyImg2ImgCombinedPipeline),
          ("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
@@ -136,11 +148,13 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
          ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
          ("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
          ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
+         ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline),
          ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
          ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
          ("lcm", LatentConsistencyModelImg2ImgPipeline),
          ("flux", FluxImg2ImgPipeline),
          ("flux-controlnet", FluxControlNetImg2ImgPipeline),
+         ("flux-control", FluxControlImg2ImgPipeline),
      ]
  )
 
@@ -155,9 +169,12 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
          ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
          ("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline),
          ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
+         ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionInpaintPipeline),
          ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
          ("flux", FluxInpaintPipeline),
          ("flux-controlnet", FluxControlNetInpaintPipeline),
+         ("flux-control", FluxControlInpaintPipeline),
+         ("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
      ]
  )
 
@@ -390,13 +407,20 @@ class AutoPipelineForText2Image(ConfigMixin):
 
          config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
          orig_class_name = config["_class_name"]
+         if "ControlPipeline" in orig_class_name:
+             to_replace = "ControlPipeline"
+         else:
+             to_replace = "Pipeline"
 
          if "controlnet" in kwargs:
-             orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
+             if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+                 orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
+             else:
+                 orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
          if "enable_pag" in kwargs:
              enable_pag = kwargs.pop("enable_pag")
              if enable_pag:
-                 orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
+                 orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
 
          text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
 
@@ -680,16 +704,28 @@ class AutoPipelineForImage2Image(ConfigMixin):
 
          # the `orig_class_name` can be:
          # `- *Pipeline` (for regular text-to-image checkpoint)
+         # - `*ControlPipeline` (for Flux tools specific checkpoint)
          # `- *Img2ImgPipeline` (for refiner checkpoint)
-         to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
+         if "Img2Img" in orig_class_name:
+             to_replace = "Img2ImgPipeline"
+         elif "ControlPipeline" in orig_class_name:
+             to_replace = "ControlPipeline"
+         else:
+             to_replace = "Pipeline"
 
          if "controlnet" in kwargs:
-             orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
+             if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+                 orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
+             else:
+                 orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
          if "enable_pag" in kwargs:
              enable_pag = kwargs.pop("enable_pag")
              if enable_pag:
                  orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
 
+         if to_replace == "ControlPipeline":
+             orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
+
          image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
 
          kwargs = {**load_config_kwargs, **kwargs}
@@ -977,15 +1013,26 @@ class AutoPipelineForInpainting(ConfigMixin):
 
          # The `orig_class_name`` can be:
          # `- *InpaintPipeline` (for inpaint-specific checkpoint)
+         # - `*ControlPipeline` (for Flux tools specific checkpoint)
          # - or *Pipeline (for regular text-to-image checkpoint)
-         to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
+         if "Inpaint" in orig_class_name:
+             to_replace = "InpaintPipeline"
+         elif "ControlPipeline" in orig_class_name:
+             to_replace = "ControlPipeline"
+         else:
+             to_replace = "Pipeline"
 
          if "controlnet" in kwargs:
-             orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
+             if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+                 orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
+             else:
+                 orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
          if "enable_pag" in kwargs:
              enable_pag = kwargs.pop("enable_pag")
              if enable_pag:
                  orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
+         if to_replace == "ControlPipeline":
+             orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
          inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
 
          kwargs = {**load_config_kwargs, **kwargs}
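Taken together, the new mapping entries and the `isinstance(..., ControlNetUnionModel)` branches mean that passing a union ControlNet to an Auto pipeline now resolves to the corresponding `*ControlNetUnion*` class. A hedged sketch; the repository ids are illustrative placeholders, not taken from this diff:

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.models.controlnets import ControlNetUnionModel

controlnet = ControlNetUnionModel.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16
)
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,  # the isinstance check above selects the Union variant
    torch_dtype=torch.float16,
)
print(type(pipe).__name__)  # expected: StableDiffusionXLControlNetUnionPipeline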
@@ -167,7 +167,7 @@ class Blip2QFormerEncoder(nn.Module):
              layer_head_mask = head_mask[i] if head_mask is not None else None
              past_key_value = past_key_values[i] if past_key_values is not None else None
 
-             if getattr(self.config, "gradient_checkpointing", False) and self.training:
+             if getattr(self.config, "gradient_checkpointing", False) and torch.is_grad_enabled():
                  if use_cache:
                      logger.warning(
                          "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -442,21 +442,39 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
          grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-         base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-         base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
 
-         grid_crops_coords = get_resize_crop_region_for_grid(
-             (grid_height, grid_width), base_size_width, base_size_height
-         )
-         freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
-             embed_dim=self.transformer.config.attention_head_dim,
-             crops_coords=grid_crops_coords,
-             grid_size=(grid_height, grid_width),
-             temporal_size=num_frames,
-         )
+         p = self.transformer.config.patch_size
+         p_t = self.transformer.config.patch_size_t
+
+         base_size_width = self.transformer.config.sample_width // p
+         base_size_height = self.transformer.config.sample_height // p
+
+         if p_t is None:
+             # CogVideoX 1.0
+             grid_crops_coords = get_resize_crop_region_for_grid(
+                 (grid_height, grid_width), base_size_width, base_size_height
+             )
+             freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                 embed_dim=self.transformer.config.attention_head_dim,
+                 crops_coords=grid_crops_coords,
+                 grid_size=(grid_height, grid_width),
+                 temporal_size=num_frames,
+                 device=device,
+             )
+         else:
+             # CogVideoX 1.5
+             base_num_frames = (num_frames + p_t - 1) // p_t
+
+             freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                 embed_dim=self.transformer.config.attention_head_dim,
+                 crops_coords=None,
+                 grid_size=(grid_height, grid_width),
+                 temporal_size=base_num_frames,
+                 grid_type="slice",
+                 max_size=(base_size_height, base_size_width),
+                 device=device,
+             )
 
-         freqs_cos = freqs_cos.to(device=device)
-         freqs_sin = freqs_sin.to(device=device)
          return freqs_cos, freqs_sin
 
      @property
@@ -481,9 +499,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
          self,
          prompt: Optional[Union[str, List[str]]] = None,
          negative_prompt: Optional[Union[str, List[str]]] = None,
-         height: int = 480,
-         width: int = 720,
-         num_frames: int = 49,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_frames: Optional[int] = None,
          num_inference_steps: int = 50,
          timesteps: Optional[List[int]] = None,
          guidance_scale: float = 6,
@@ -583,14 +601,13 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
              `tuple`. When returning a tuple, the first element is a list with the generated images.
          """
 
-         if num_frames > 49:
-             raise ValueError(
-                 "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
-             )
-
          if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
              callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
+         height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+         width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+         num_frames = num_frames or self.transformer.config.sample_frames
+
          num_videos_per_prompt = 1
 
          # 1. Check inputs. Raise error if not correct
@@ -640,7 +657,16 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
          timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
          self._num_timesteps = len(timesteps)
 
-         # 5. Prepare latents.
+         # 5. Prepare latents
+         latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+         # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
+         patch_size_t = self.transformer.config.patch_size_t
+         additional_frames = 0
+         if patch_size_t is not None and latent_frames % patch_size_t != 0:
+             additional_frames = patch_size_t - latent_frames % patch_size_t
+             num_frames += additional_frames * self.vae_scale_factor_temporal
+
          latent_channels = self.transformer.config.in_channels
          latents = self.prepare_latents(
              batch_size * num_videos_per_prompt,
@@ -730,6 +756,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
              progress_bar.update()
 
          if not output_type == "latent":
+             # Discard any padding frames that were added for CogVideoX 1.5
+             latents = latents[:, additional_frames:]
              video = self.decode_latents(latents)
              video = self.video_processor.postprocess_video(video=video, output_type=output_type)
          else:
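These CogVideoX hunks make `height`, `width`, and `num_frames` optional, defaulting to the transformer's `sample_height`/`sample_width` (scaled by the spatial VAE factor) and `sample_frames`, which lets one code path serve both CogVideoX 1.0 and 1.5 checkpoints; for 1.5, latent frames are padded to a multiple of `patch_size_t` and the padding is dropped before decoding. A hedged sketch of a call that relies on those defaults, with an illustrative checkpoint id:

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX1.5-5B", torch_dtype=torch.bfloat16
).to("cuda")

# height/width/num_frames omitted: they fall back to the transformer's
# sample_height * vae_scale_factor_spatial, sample_width * vae_scale_factor_spatial,
# and sample_frames, so the same call works for 1.0 and 1.5 checkpoints.
video = pipe(prompt="a koi pond in the rain", num_inference_steps=50).frames[0]
export_to_video(video, "cogvideox.mp4", fps=16)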
@@ -488,21 +488,39 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
          grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-         base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-         base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
 
-         grid_crops_coords = get_resize_crop_region_for_grid(
-             (grid_height, grid_width), base_size_width, base_size_height
-         )
-         freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
-             embed_dim=self.transformer.config.attention_head_dim,
-             crops_coords=grid_crops_coords,
-             grid_size=(grid_height, grid_width),
-             temporal_size=num_frames,
-         )
+         p = self.transformer.config.patch_size
+         p_t = self.transformer.config.patch_size_t
+
+         base_size_width = self.transformer.config.sample_width // p
+         base_size_height = self.transformer.config.sample_height // p
+
+         if p_t is None:
+             # CogVideoX 1.0
+             grid_crops_coords = get_resize_crop_region_for_grid(
+                 (grid_height, grid_width), base_size_width, base_size_height
+             )
+             freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                 embed_dim=self.transformer.config.attention_head_dim,
+                 crops_coords=grid_crops_coords,
+                 grid_size=(grid_height, grid_width),
+                 temporal_size=num_frames,
+                 device=device,
+             )
+         else:
+             # CogVideoX 1.5
+             base_num_frames = (num_frames + p_t - 1) // p_t
+
+             freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                 embed_dim=self.transformer.config.attention_head_dim,
+                 crops_coords=None,
+                 grid_size=(grid_height, grid_width),
+                 temporal_size=base_num_frames,
+                 grid_type="slice",
+                 max_size=(base_size_height, base_size_width),
+                 device=device,
+             )
 
-         freqs_cos = freqs_cos.to(device=device)
-         freqs_sin = freqs_sin.to(device=device)
          return freqs_cos, freqs_sin
 
      @property
@@ -528,8 +546,8 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
          prompt: Optional[Union[str, List[str]]] = None,
          negative_prompt: Optional[Union[str, List[str]]] = None,
          control_video: Optional[List[Image.Image]] = None,
-         height: int = 480,
-         width: int = 720,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
          num_inference_steps: int = 50,
          timesteps: Optional[List[int]] = None,
          guidance_scale: float = 6,
@@ -634,6 +652,13 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
          if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
              callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
+         if control_video is not None and isinstance(control_video[0], Image.Image):
+             control_video = [control_video]
+
+         height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+         width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+         num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2)
+
          num_videos_per_prompt = 1
 
          # 1. Check inputs. Raise error if not correct
@@ -660,9 +685,6 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
          else:
              batch_size = prompt_embeds.shape[0]
 
-         if control_video is not None and isinstance(control_video[0], Image.Image):
-             control_video = [control_video]
-
          device = self._execution_device
 
          # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
@@ -688,9 +710,18 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
          timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
          self._num_timesteps = len(timesteps)
 
-         # 5. Prepare latents.
+         # 5. Prepare latents
+         latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+         # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
+         patch_size_t = self.transformer.config.patch_size_t
+         if patch_size_t is not None and latent_frames % patch_size_t != 0:
+             raise ValueError(
+                 f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video "
+                 f"contains {latent_frames=}, which is not divisible."
+             )
+
          latent_channels = self.transformer.config.in_channels // 2
-         num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2)
          latents = self.prepare_latents(
              batch_size * num_videos_per_prompt,
              latent_channels,
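Unlike the text-to-video pipeline, the Fun Control variant cannot pad the control video, so it raises when the derived latent frame count is not divisible by `patch_size_t`. A small sketch of that rule, using vae_scale_factor_temporal=4 and patch_size_t=2 purely as illustrative CogVideoX 1.5-style values:

vae_scale_factor_temporal = 4  # illustrative value
patch_size_t = 2               # illustrative value

def latent_frames(num_frames: int) -> int:
    # mirrors: (num_frames - 1) // self.vae_scale_factor_temporal + 1
    return (num_frames - 1) // vae_scale_factor_temporal + 1

for n in (49, 53, 81, 85):
    lf = latent_frames(n)
    print(n, lf, "ok" if lf % patch_size_t == 0 else "would raise ValueError")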