diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/audioldm2/modeling_audioldm2.py

@@ -1112,7 +1112,7 @@ class CrossAttnDownBlock2D(nn.Module):
         )
 
         for i in range(num_layers):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1290,7 +1290,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         )
 
         for i in range(len(self.resnets[1:])):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1464,7 +1464,7 @@ class CrossAttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
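
All three hunks make the same fix: these blocks used to gate gradient checkpointing on `self.training`, so a frozen module kept in eval mode (the usual situation when only a LoRA or adapter on top of it is being trained) never checkpointed even though gradients flowed through it. Gating on `torch.is_grad_enabled()` activates checkpointing whenever autograd is actually recording. A toy sketch of the behavioral difference, not the diffusers code itself:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    # Stand-in for a diffusers block that supports gradient checkpointing.
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(16, 16)
        self.gradient_checkpointing = True

    def forward(self, x):
        # The old gate (`self.training and ...`) skipped checkpointing for a
        # frozen block kept in eval mode; the new gate only asks whether
        # autograd is currently recording.
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            return checkpoint(self.layer, x, use_reentrant=False)
        return self.layer(x)


block = ToyBlock().eval()                   # e.g. a frozen backbone during LoRA training
x = torch.randn(2, 16, requires_grad=True)
block(x).sum().backward()                   # checkpointing is active despite eval mode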
diffusers/pipelines/aura_flow/pipeline_aura_flow.py

@@ -53,7 +53,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
@@ -387,7 +387,6 @@ class AuraFlowPipeline(DiffusionPipeline):
         prompt: Union[str, List[str]] = None,
         negative_prompt: Union[str, List[str]] = None,
         num_inference_steps: int = 50,
-        timesteps: List[int] = None,
         sigmas: List[float] = None,
         guidance_scale: float = 3.5,
         num_images_per_prompt: Optional[int] = 1,
@@ -424,10 +423,6 @@ class AuraFlowPipeline(DiffusionPipeline):
             sigmas (`List[float]`, *optional*):
                 Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
                 `num_inference_steps` and `timesteps` must be `None`.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
             guidance_scale (`float`, *optional*, defaults to 5.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -522,9 +517,7 @@ class AuraFlowPipeline(DiffusionPipeline):
         # 4. Prepare timesteps
 
         # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler, num_inference_steps, device, timesteps, sigmas
-        )
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
 
         # 5. Prepare latents.
         latent_channels = self.transformer.config.in_channels
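
Net effect: AuraFlow custom schedules are now expressed exclusively through `sigmas`; the dangling `timesteps` parameter, which was documented but would be forwarded to a scheduler that does not accept it, is removed. A usage sketch against the 0.32.0 signature; the checkpoint id is the usual public AuraFlow repo and the linear schedule is purely illustrative:

import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# A custom noise schedule is passed as `sigmas`; `timesteps` is no longer
# part of AuraFlowPipeline.__call__ in 0.32.0.
sigmas = [1.0 - i / 25 for i in range(25)]
image = pipe(prompt="a watercolor fox", sigmas=sigmas, guidance_scale=3.5).images[0]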
diffusers/pipelines/auto_pipeline.py

@@ -18,8 +18,10 @@ from collections import OrderedDict
 from huggingface_hub.utils import validate_hf_hub_args
 
 from ..configuration_utils import ConfigMixin
+from ..models.controlnets import ControlNetUnionModel
 from ..utils import is_sentencepiece_available
 from .aura_flow import AuraFlowPipeline
+from .cogview3 import CogView3PlusPipeline
 from .controlnet import (
     StableDiffusionControlNetImg2ImgPipeline,
     StableDiffusionControlNetInpaintPipeline,
@@ -27,9 +29,22 @@ from .controlnet import (
     StableDiffusionXLControlNetImg2ImgPipeline,
     StableDiffusionXLControlNetInpaintPipeline,
     StableDiffusionXLControlNetPipeline,
+    StableDiffusionXLControlNetUnionImg2ImgPipeline,
+    StableDiffusionXLControlNetUnionInpaintPipeline,
+    StableDiffusionXLControlNetUnionPipeline,
 )
 from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
-from .flux import FluxPipeline
+from .flux import (
+    FluxControlImg2ImgPipeline,
+    FluxControlInpaintPipeline,
+    FluxControlNetImg2ImgPipeline,
+    FluxControlNetInpaintPipeline,
+    FluxControlNetPipeline,
+    FluxControlPipeline,
+    FluxImg2ImgPipeline,
+    FluxInpaintPipeline,
+    FluxPipeline,
+)
 from .hunyuandit import HunyuanDiTPipeline
 from .kandinsky import (
     KandinskyCombinedPipeline,
@@ -49,12 +64,18 @@ from .kandinsky2_2 import (
 )
 from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
 from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
+from .lumina import LuminaText2ImgPipeline
 from .pag import (
     HunyuanDiTPAGPipeline,
     PixArtSigmaPAGPipeline,
+    StableDiffusion3PAGImg2ImgPipeline,
     StableDiffusion3PAGPipeline,
+    StableDiffusionControlNetPAGInpaintPipeline,
     StableDiffusionControlNetPAGPipeline,
+    StableDiffusionPAGImg2ImgPipeline,
+    StableDiffusionPAGInpaintPipeline,
     StableDiffusionPAGPipeline,
+    StableDiffusionXLControlNetPAGImg2ImgPipeline,
     StableDiffusionXLControlNetPAGPipeline,
     StableDiffusionXLPAGImg2ImgPipeline,
     StableDiffusionXLPAGInpaintPipeline,
@@ -94,6 +115,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
         ("kandinsky3", Kandinsky3Pipeline),
         ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
         ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
+        ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline),
         ("wuerstchen", WuerstchenCombinedPipeline),
         ("cascade", StableCascadeCombinedPipeline),
         ("lcm", LatentConsistencyModelPipeline),
@@ -106,6 +128,10 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
         ("pixart-sigma-pag", PixArtSigmaPAGPipeline),
         ("auraflow", AuraFlowPipeline),
         ("flux", FluxPipeline),
+        ("flux-control", FluxControlPipeline),
+        ("flux-controlnet", FluxControlNetPipeline),
+        ("lumina", LuminaText2ImgPipeline),
+        ("cogview3", CogView3PlusPipeline),
     ]
 )
 
@@ -114,14 +140,21 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
         ("stable-diffusion", StableDiffusionImg2ImgPipeline),
         ("stable-diffusion-xl", StableDiffusionXLImg2ImgPipeline),
         ("stable-diffusion-3", StableDiffusion3Img2ImgPipeline),
+        ("stable-diffusion-3-pag", StableDiffusion3PAGImg2ImgPipeline),
         ("if", IFImg2ImgPipeline),
         ("kandinsky", KandinskyImg2ImgCombinedPipeline),
         ("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
         ("kandinsky3", Kandinsky3Img2ImgPipeline),
         ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
+        ("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
         ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
+        ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline),
         ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
+        ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
         ("lcm", LatentConsistencyModelImg2ImgPipeline),
+        ("flux", FluxImg2ImgPipeline),
+        ("flux-controlnet", FluxControlNetImg2ImgPipeline),
+        ("flux-control", FluxControlImg2ImgPipeline),
     ]
 )
 
@@ -134,8 +167,14 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
         ("kandinsky", KandinskyInpaintCombinedPipeline),
         ("kandinsky22", KandinskyV22InpaintCombinedPipeline),
         ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
+        ("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline),
         ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
+        ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionInpaintPipeline),
         ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
+        ("flux", FluxInpaintPipeline),
+        ("flux-controlnet", FluxControlNetInpaintPipeline),
+        ("flux-control", FluxControlInpaintPipeline),
+        ("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
     ]
 )
 
@@ -161,12 +200,12 @@ _AUTO_INPAINT_DECODER_PIPELINES_MAPPING = OrderedDict(
 )
 
 if is_sentencepiece_available():
-    from .kolors import KolorsPipeline
+    from .kolors import KolorsImg2ImgPipeline, KolorsPipeline
     from .pag import KolorsPAGPipeline
 
     AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
    AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors-pag"] = KolorsPAGPipeline
-    AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
+    AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsImg2ImgPipeline
 
 SUPPORTED_TASKS_MAPPINGS = [
     AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
@@ -368,13 +407,20 @@ class AutoPipelineForText2Image(ConfigMixin):
 
         config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
         orig_class_name = config["_class_name"]
+        if "ControlPipeline" in orig_class_name:
+            to_replace = "ControlPipeline"
+        else:
+            to_replace = "Pipeline"
 
         if "controlnet" in kwargs:
-            orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
+            if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+                orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
+            else:
+                orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
         if "enable_pag" in kwargs:
             enable_pag = kwargs.pop("enable_pag")
             if enable_pag:
-                orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
+                orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
 
         text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
 
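With the new `to_replace` computation, `AutoPipelineForText2Image` can resolve the union ControlNet variants added in this release. A usage sketch; both checkpoint ids are illustrative (any SDXL base plus a ControlNet-Union checkpoint pair should behave the same way):

import torch
from diffusers import AutoPipelineForText2Image, ControlNetUnionModel

controlnet = ControlNetUnionModel.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16
)
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
# Resolved via AUTO_TEXT2IMAGE_PIPELINES_MAPPING["stable-diffusion-xl-controlnet-union"]
print(type(pipe).__name__)  # StableDiffusionXLControlNetUnionPipeline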
@@ -656,12 +702,29 @@ class AutoPipelineForImage2Image(ConfigMixin):
         config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
         orig_class_name = config["_class_name"]
 
+        # the `orig_class_name` can be:
+        # `- *Pipeline` (for regular text-to-image checkpoint)
+        # - `*ControlPipeline` (for Flux tools specific checkpoint)
+        # `- *Img2ImgPipeline` (for refiner checkpoint)
+        if "Img2Img" in orig_class_name:
+            to_replace = "Img2ImgPipeline"
+        elif "ControlPipeline" in orig_class_name:
+            to_replace = "ControlPipeline"
+        else:
+            to_replace = "Pipeline"
+
         if "controlnet" in kwargs:
-            orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
+            if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+                orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
+            else:
+                orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
         if "enable_pag" in kwargs:
             enable_pag = kwargs.pop("enable_pag")
             if enable_pag:
-                orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
+                orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
+
+        if to_replace == "ControlPipeline":
+            orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
 
         image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
 
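The comment block above spells out the three shapes `orig_class_name` can take; the suffix that gets rewritten is chosen accordingly. Worked through in isolation for an SDXL refiner config loaded with `enable_pag=True`:

# Standalone rendition of the class-name rewriting shown above.
orig_class_name = "StableDiffusionXLImg2ImgPipeline"

if "Img2Img" in orig_class_name:
    to_replace = "Img2ImgPipeline"
elif "ControlPipeline" in orig_class_name:
    to_replace = "ControlPipeline"
else:
    to_replace = "Pipeline"

# enable_pag=True, no `controlnet` kwarg, so only the PAG rewrite fires:
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
print(orig_class_name)  # StableDiffusionXLPAGImg2ImgPipeline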
@@ -948,13 +1011,28 @@ class AutoPipelineForInpainting(ConfigMixin):
         config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
         orig_class_name = config["_class_name"]
 
+        # The `orig_class_name`` can be:
+        # `- *InpaintPipeline` (for inpaint-specific checkpoint)
+        # - `*ControlPipeline` (for Flux tools specific checkpoint)
+        # - or *Pipeline (for regular text-to-image checkpoint)
+        if "Inpaint" in orig_class_name:
+            to_replace = "InpaintPipeline"
+        elif "ControlPipeline" in orig_class_name:
+            to_replace = "ControlPipeline"
+        else:
+            to_replace = "Pipeline"
+
         if "controlnet" in kwargs:
-            orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
+            if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+                orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
+            else:
+                orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
         if "enable_pag" in kwargs:
             enable_pag = kwargs.pop("enable_pag")
             if enable_pag:
-                orig_class_name = config["_class_name"].replace("Pipeline", "PAGPipeline")
-
+                orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
+        if to_replace == "ControlPipeline":
+            orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
         inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
 
         kwargs = {**load_config_kwargs, **kwargs}
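`AutoPipelineForInpainting` follows the same pattern, with a final rewrite for Flux tools checkpoints whose config still names a `*ControlPipeline` class. Worked through in isolation:

# Standalone rendition for a Flux tools checkpoint loaded through
# AutoPipelineForInpainting.
orig_class_name = "FluxControlPipeline"

if "Inpaint" in orig_class_name:
    to_replace = "InpaintPipeline"
elif "ControlPipeline" in orig_class_name:
    to_replace = "ControlPipeline"
else:
    to_replace = "Pipeline"

# No `controlnet` kwarg and no PAG requested, so only the final rewrite fires:
if to_replace == "ControlPipeline":
    orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
print(orig_class_name)  # FluxControlInpaintPipeline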
diffusers/pipelines/blip_diffusion/modeling_blip2.py

@@ -167,7 +167,7 @@ class Blip2QFormerEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if getattr(self.config, "gradient_checkpointing", False) and torch.is_grad_enabled():
                 if use_cache:
                     logger.warning(
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."

diffusers/pipelines/cogvideo/__init__.py

@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
     _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
+    _import_structure["pipeline_cogvideox_fun_control"] = ["CogVideoXFunControlPipeline"]
     _import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"]
     _import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"]
 
@@ -35,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from ...utils.dummy_torch_and_transformers_objects import *
 else:
     from .pipeline_cogvideox import CogVideoXPipeline
+    from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline
     from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
     from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline
 
diffusers/pipelines/cogvideo/pipeline_cogvideox.py

@@ -15,12 +15,13 @@
 
 import inspect
 import math
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from transformers import T5EncoderModel, T5Tokenizer
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import CogVideoXLoraLoaderMixin
 from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
 from ...models.embeddings import get_3d_rotary_pos_embed
 from ...pipelines.pipeline_utils import DiffusionPipeline
@@ -85,7 +86,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
@@ -136,7 +137,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class CogVideoXPipeline(DiffusionPipeline):
+class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
     r"""
     Pipeline for text-to-video generation using CogVideoX.
 
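Because `CogVideoXPipeline` now mixes in `CogVideoXLoraLoaderMixin`, LoRA checkpoints can be loaded straight onto the pipeline. A usage sketch; the LoRA repo id and weight filename below are placeholders for an actual checkpoint:

import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

# `load_lora_weights` is provided by the new CogVideoXLoraLoaderMixin base class.
pipe.load_lora_weights("your-org/cogvideox-lora", weight_name="pytorch_lora_weights.safetensors")

video = pipe(prompt="a panda playing guitar", num_inference_steps=50).frames[0]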
@@ -187,6 +188,9 @@ class CogVideoXPipeline(DiffusionPipeline):
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
         )
+        self.vae_scaling_factor_image = (
+            self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+        )
 
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
@@ -316,6 +320,12 @@ class CogVideoXPipeline(DiffusionPipeline):
     def prepare_latents(
         self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         shape = (
             batch_size,
             (num_frames - 1) // self.vae_scale_factor_temporal + 1,
@@ -323,11 +333,6 @@ class CogVideoXPipeline(DiffusionPipeline):
             height // self.vae_scale_factor_spatial,
             width // self.vae_scale_factor_spatial,
         )
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
 
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -340,7 +345,7 @@ class CogVideoXPipeline(DiffusionPipeline):
 
     def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
-        latents = 1 / self.vae.config.scaling_factor * latents
+        latents = 1 / self.vae_scaling_factor_image * latents
 
         frames = self.vae.decode(latents).sample
         return frames
@@ -437,22 +442,39 @@ class CogVideoXPipeline(DiffusionPipeline):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-        base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
 
-        grid_crops_coords = get_resize_crop_region_for_grid(
-            (grid_height, grid_width), base_size_width, base_size_height
-        )
-        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
-            embed_dim=self.transformer.config.attention_head_dim,
-            crops_coords=grid_crops_coords,
-            grid_size=(grid_height, grid_width),
-            temporal_size=num_frames,
-            use_real=True,
-        )
+        p = self.transformer.config.patch_size
+        p_t = self.transformer.config.patch_size_t
+
+        base_size_width = self.transformer.config.sample_width // p
+        base_size_height = self.transformer.config.sample_height // p
+
+        if p_t is None:
+            # CogVideoX 1.0
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=grid_crops_coords,
+                grid_size=(grid_height, grid_width),
+                temporal_size=num_frames,
+                device=device,
+            )
+        else:
+            # CogVideoX 1.5
+            base_num_frames = (num_frames + p_t - 1) // p_t
+
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
+                device=device,
+            )
 
-        freqs_cos = freqs_cos.to(device=device)
-        freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
     @property
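
The rewritten RoPE helper branches on `patch_size_t`: CogVideoX 1.0 checkpoints (where it is `None`) keep the resize-crop grid path, while 1.5 checkpoints ceil-divide the temporal length by the temporal patch size and build a "slice" grid capped at the base size. The temporal arithmetic, with illustrative values:

# Assumed CogVideoX 1.5 values: patch_size_t = 2, and a temporal length of
# 22 latent frames (already padded to a multiple of patch_size_t; see the
# "5. Prepare latents" hunk further down).
num_frames = 22
p_t = 2
base_num_frames = (num_frames + p_t - 1) // p_t  # ceil division
print(base_num_frames)  # 11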
@@ -463,6 +485,10 @@ class CogVideoXPipeline(DiffusionPipeline):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
     @property
     def interrupt(self):
         return self._interrupt
@@ -473,9 +499,9 @@ class CogVideoXPipeline(DiffusionPipeline):
         self,
         prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
-        height: int = 480,
-        width: int = 720,
-        num_frames: int = 49,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: Optional[int] = None,
         num_inference_steps: int = 50,
         timesteps: Optional[List[int]] = None,
         guidance_scale: float = 6,
@@ -488,6 +514,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: str = "pil",
         return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,
@@ -505,14 +532,14 @@ class CogVideoXPipeline(DiffusionPipeline):
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The height in pixels of the generated image. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The width in pixels of the generated image. This is set to 720 by default for the best results.
             num_frames (`int`, defaults to `48`):
                 Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
                 contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
-                num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
+                num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
                 needs to be satisfied is that of divisibility mentioned above.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -549,6 +576,10 @@ class CogVideoXPipeline(DiffusionPipeline):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             callback_on_step_end (`Callable`, *optional*):
                 A function that calls at the end of each denoising steps during the inference. The function is called
                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -570,16 +601,13 @@ class CogVideoXPipeline(DiffusionPipeline):
             `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
 
-        if num_frames > 49:
-            raise ValueError(
-                "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
-            )
-
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
-        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
-        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
+        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+        num_frames = num_frames or self.transformer.config.sample_frames
+
         num_videos_per_prompt = 1
 
         # 1. Check inputs. Raise error if not correct
@@ -593,6 +621,7 @@ class CogVideoXPipeline(DiffusionPipeline):
             negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
        self._interrupt = False
 
         # 2. Default call parameters
@@ -628,7 +657,16 @@ class CogVideoXPipeline(DiffusionPipeline):
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
         self._num_timesteps = len(timesteps)
 
-        # 5. Prepare latents.
+        # 5. Prepare latents
+        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
+        patch_size_t = self.transformer.config.patch_size_t
+        additional_frames = 0
+        if patch_size_t is not None and latent_frames % patch_size_t != 0:
+            additional_frames = patch_size_t - latent_frames % patch_size_t
+            num_frames += additional_frames * self.vae_scale_factor_temporal
+
         latent_channels = self.transformer.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
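
The padding step above rounds the latent frame count up to a multiple of `patch_size_t` by requesting extra pixel frames. Worked through with illustrative values (`patch_size_t = 2` and `vae_scale_factor_temporal = 4`, typical for CogVideoX 1.5):

# Frame-padding arithmetic from the "5. Prepare latents" block above.
num_frames = 81
vae_scale_factor_temporal = 4
patch_size_t = 2

latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1   # 21
additional_frames = 0
if patch_size_t is not None and latent_frames % patch_size_t != 0:
    additional_frames = patch_size_t - latent_frames % patch_size_t  # 1
    num_frames += additional_frames * vae_scale_factor_temporal      # 85

print(latent_frames, additional_frames, num_frames)  # 21 1 85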
@@ -674,6 +712,7 @@ class CogVideoXPipeline(DiffusionPipeline):
                     encoder_hidden_states=prompt_embeds,
                     timestep=timestep,
                     image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
                     return_dict=False,
                 )[0]
                 noise_pred = noise_pred.float()
@@ -717,6 +756,8 @@ class CogVideoXPipeline(DiffusionPipeline):
                 progress_bar.update()
 
         if not output_type == "latent":
+            # Discard any padding frames that were added for CogVideoX 1.5
+            latents = latents[:, additional_frames:]
             video = self.decode_latents(latents)
             video = self.video_processor.postprocess_video(video=video, output_type=output_type)
         else:
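
At decode time the pipeline slices the same `additional_frames` count off the front of the latents, so the padding never reaches the output video, and the new `attention_kwargs` dict is forwarded to the transformer on every denoising step. An end-to-end sketch with a CogVideoX 1.5 checkpoint; the `scale` entry assumes the PEFT-style LoRA scaling convention and the repo id is illustrative:

import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX1.5-5B", torch_dtype=torch.bfloat16).to("cuda")

video = pipe(
    prompt="a koala surfing at sunset",
    num_frames=81,                    # 21 latent frames, padded to 22 internally
    attention_kwargs={"scale": 0.8},  # assumed LoRA-scale kwarg, forwarded each step
).frames[0]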