diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
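Beyond the AnimateDiff hunks reproduced below, the most visible structural addition in this range is the new diffusers/quantizers package (bitsandbytes, GGUF and TorchAO backends plus quantization_config.py). A minimal sketch of how that is typically wired up, assuming the diffusers BitsAndBytesConfig mirrors the transformers-style API; the checkpoint id and the 4-bit settings are illustrative, not taken from this diff:

```python
import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

# Sketch only: quantize the transformer with the bitsandbytes backend shipped
# under the new diffusers.quantizers package.
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",        # illustrative checkpoint id
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
```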
diffusers/pipelines/allegro/pipeline_output.py
@@ -0,0 +1,23 @@
+ from dataclasses import dataclass
+ from typing import List, Union
+
+ import numpy as np
+ import PIL
+ import torch
+
+ from diffusers.utils import BaseOutput
+
+
+ @dataclass
+ class AllegroPipelineOutput(BaseOutput):
+ r"""
+ Output class for Allegro pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
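AllegroPipelineOutput follows the usual BaseOutput contract, so the frames field can be read by attribute or by index. A small illustrative check; the tensor shape is made up:

```python
import torch
from diffusers.pipelines.allegro.pipeline_output import AllegroPipelineOutput

# Illustrative shape: (batch_size, num_frames, channels, height, width)
out = AllegroPipelineOutput(frames=torch.randn(1, 8, 3, 64, 64))
assert torch.equal(out.frames, out[0])              # BaseOutput supports index access
assert torch.equal(out.frames, out.to_tuple()[0])   # ...and tuple conversion
```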
diffusers/pipelines/animatediff/__init__.py
@@ -26,6 +26,7 @@ else:
  _import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
  _import_structure["pipeline_animatediff_sparsectrl"] = ["AnimateDiffSparseControlNetPipeline"]
  _import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
+ _import_structure["pipeline_animatediff_video2video_controlnet"] = ["AnimateDiffVideoToVideoControlNetPipeline"]

  if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
  try:
@@ -40,6 +41,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
  from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
  from .pipeline_animatediff_sparsectrl import AnimateDiffSparseControlNetPipeline
  from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
+ from .pipeline_animatediff_video2video_controlnet import AnimateDiffVideoToVideoControlNetPipeline
  from .pipeline_output import AnimateDiffPipelineOutput

  else:
diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -432,7 +432,6 @@ class AnimateDiffPipeline(
  extra_step_kwargs["generator"] = generator
  return extra_step_kwargs

- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
  def check_inputs(
  self,
  prompt,
@@ -470,8 +469,8 @@ class AnimateDiffPipeline(
  raise ValueError(
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
  )
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt is not None and not isinstance(prompt, (str, list, dict)):
+ raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}")

  if negative_prompt is not None and negative_prompt_embeds is not None:
  raise ValueError(
@@ -557,11 +556,15 @@ class AnimateDiffPipeline(
  def num_timesteps(self):
  return self._num_timesteps

+ @property
+ def interrupt(self):
+ return self._interrupt
+
  @torch.no_grad()
  @replace_example_docstring(EXAMPLE_DOC_STRING)
  def __call__(
  self,
- prompt: Union[str, List[str]] = None,
+ prompt: Optional[Union[str, List[str]]] = None,
  num_frames: Optional[int] = 16,
  height: Optional[int] = None,
  width: Optional[int] = None,
@@ -701,9 +704,10 @@ class AnimateDiffPipeline(
  self._guidance_scale = guidance_scale
  self._clip_skip = clip_skip
  self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False

  # 2. Define call parameters
- if prompt is not None and isinstance(prompt, str):
+ if prompt is not None and isinstance(prompt, (str, dict)):
  batch_size = 1
  elif prompt is not None and isinstance(prompt, list):
  batch_size = len(prompt)
@@ -716,22 +720,39 @@ class AnimateDiffPipeline(
  text_encoder_lora_scale = (
  self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
  )
- prompt_embeds, negative_prompt_embeds = self.encode_prompt(
- prompt,
- device,
- num_videos_per_prompt,
- self.do_classifier_free_guidance,
- negative_prompt,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- lora_scale=text_encoder_lora_scale,
- clip_skip=self.clip_skip,
- )
- # For classifier free guidance, we need to do two forward passes.
- # Here we concatenate the unconditional and text embeddings into a single batch
- # to avoid doing two forward passes
- if self.do_classifier_free_guidance:
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ if self.free_noise_enabled:
+ prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
+ prompt=prompt,
+ num_frames=num_frames,
+ device=device,
+ num_videos_per_prompt=num_videos_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+ else:
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

  if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
  image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -783,6 +804,9 @@ class AnimateDiffPipeline(
  # 8. Denoising loop
  with self.progress_bar(total=self._num_timesteps) as progress_bar:
  for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
  # expand the latents if we are doing classifier free guidance
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
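The interrupt property and the `if self.interrupt: continue` guard added above allow a step callback to stop the denoising loop early. A minimal sketch, assuming the `callback_on_step_end` hook these pipelines already expose; the prompt and step count are illustrative:

```python
# Sketch: stop an AnimateDiffPipeline run after 10 steps via the new interrupt flag.
def stop_after_10_steps(pipe, step, timestep, callback_kwargs):
    if step >= 10:
        pipe._interrupt = True  # read back through the new `interrupt` property each iteration
    return callback_kwargs

# output = pipe(
#     prompt="a panda playing a guitar",        # illustrative prompt
#     num_inference_steps=25,
#     callback_on_step_end=stop_after_10_steps,
# )
```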
diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
@@ -21,14 +21,20 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV

  from ...image_processor import PipelineImageInput
  from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
- from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
+ from ...models import (
+ AutoencoderKL,
+ ControlNetModel,
+ ImageProjection,
+ MultiControlNetModel,
+ UNet2DConditionModel,
+ UNetMotionModel,
+ )
  from ...models.lora import adjust_lora_scale_text_encoder
  from ...models.unets.unet_motion_model import MotionAdapter
  from ...schedulers import KarrasDiffusionSchedulers
  from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
  from ...utils.torch_utils import is_compiled_module, randn_tensor
  from ...video_processor import VideoProcessor
- from ..controlnet.multicontrolnet import MultiControlNetModel
  from ..free_init_utils import FreeInitMixin
  from ..free_noise_utils import AnimateDiffFreeNoiseMixin
  from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
@@ -505,8 +511,8 @@ class AnimateDiffControlNetPipeline(
  raise ValueError(
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
  )
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt is not None and not isinstance(prompt, (str, list, dict)):
+ raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")

  if negative_prompt is not None and negative_prompt_embeds is not None:
  raise ValueError(
@@ -699,6 +705,10 @@ class AnimateDiffControlNetPipeline(
  def num_timesteps(self):
  return self._num_timesteps

+ @property
+ def interrupt(self):
+ return self._interrupt
+
  @torch.no_grad()
  def __call__(
  self,
@@ -858,9 +868,10 @@ class AnimateDiffControlNetPipeline(
  self._guidance_scale = guidance_scale
  self._clip_skip = clip_skip
  self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False

  # 2. Define call parameters
- if prompt is not None and isinstance(prompt, str):
+ if prompt is not None and isinstance(prompt, (str, dict)):
  batch_size = 1
  elif prompt is not None and isinstance(prompt, list):
  batch_size = len(prompt)
@@ -883,22 +894,39 @@ class AnimateDiffControlNetPipeline(
  text_encoder_lora_scale = (
  cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
  )
- prompt_embeds, negative_prompt_embeds = self.encode_prompt(
- prompt,
- device,
- num_videos_per_prompt,
- self.do_classifier_free_guidance,
- negative_prompt,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- lora_scale=text_encoder_lora_scale,
- clip_skip=self.clip_skip,
- )
- # For classifier free guidance, we need to do two forward passes.
- # Here we concatenate the unconditional and text embeddings into a single batch
- # to avoid doing two forward passes
- if self.do_classifier_free_guidance:
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ if self.free_noise_enabled:
+ prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
+ prompt=prompt,
+ num_frames=num_frames,
+ device=device,
+ num_videos_per_prompt=num_videos_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+ else:
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

  if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
  image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -990,6 +1018,9 @@ class AnimateDiffControlNetPipeline(
  # 8. Denoising loop
  with self.progress_bar(total=self._num_timesteps) as progress_bar:
  for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
  # expand the latents if we are doing classifier free guidance
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1002,7 +1033,6 @@ class AnimateDiffControlNetPipeline(
  else:
  control_model_input = latent_model_input
  controlnet_prompt_embeds = prompt_embeds
- controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)

  if isinstance(controlnet_keep[i], list):
  cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
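These hunks also reflect the ControlNet reorganization visible in the file list (diffusers/models/controlnets/): MultiControlNetModel is now imported from diffusers.models rather than from diffusers.pipelines.controlnet.multicontrolnet, which this diff shrinks to a thin re-export (+7 -178). A sketch of the updated import:

```python
# New location in 0.32.x: the models package re-exports MultiControlNetModel.
from diffusers.models import ControlNetModel, MultiControlNetModel

# Old 0.30.x location, kept only as a small compatibility shim:
# from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
```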
diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
@@ -113,9 +113,21 @@ EXAMPLE_DOC_STRING = """

  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
  """
  std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
  std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -135,7 +147,7 @@ def retrieve_timesteps(
  sigmas: Optional[List[float]] = None,
  **kwargs,
  ):
- """
+ r"""
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -1143,6 +1155,8 @@ class AnimateDiffSDXLPipeline(
  add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
  add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

+ prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
+
  prompt_embeds = prompt_embeds.to(device)
  add_text_embeds = add_text_embeds.to(device)
  add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1)
diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
  from ...image_processor import PipelineImageInput, VaeImageProcessor
  from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
  from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
- from ...models.controlnet_sparsectrl import SparseControlNetModel
+ from ...models.controlnets.controlnet_sparsectrl import SparseControlNetModel
  from ...models.lora import adjust_lora_scale_text_encoder
  from ...models.unets.unet_motion_model import MotionAdapter
  from ...schedulers import KarrasDiffusionSchedulers
@@ -878,6 +878,8 @@ class AnimateDiffSparseControlNetPipeline(
  if self.do_classifier_free_guidance:
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

+ prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
+
  # 4. Prepare IP-Adapter embeddings
  if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
  image_embeds = self.prepare_ip_adapter_image_embeds(
diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
@@ -119,7 +119,7 @@ def retrieve_timesteps(
  sigmas: Optional[List[float]] = None,
  **kwargs,
  ):
- """
+ r"""
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -246,7 +246,6 @@ class AnimateDiffVideoToVideoPipeline(
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
  self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)

- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
  def encode_prompt(
  self,
  prompt,
@@ -299,7 +298,7 @@ class AnimateDiffVideoToVideoPipeline(
  else:
  scale_lora_layers(self.text_encoder, lora_scale)

- if prompt is not None and isinstance(prompt, str):
+ if prompt is not None and isinstance(prompt, (str, dict)):
  batch_size = 1
  elif prompt is not None and isinstance(prompt, list):
  batch_size = len(prompt)
@@ -582,8 +581,8 @@ class AnimateDiffVideoToVideoPipeline(
  raise ValueError(
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
  )
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt is not None and not isinstance(prompt, (str, list, dict)):
+ raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")

  if negative_prompt is not None and negative_prompt_embeds is not None:
  raise ValueError(
@@ -628,23 +627,20 @@ class AnimateDiffVideoToVideoPipeline(

  def prepare_latents(
  self,
- video,
- height,
- width,
- num_channels_latents,
- batch_size,
- timestep,
- dtype,
- device,
- generator,
- latents=None,
+ video: Optional[torch.Tensor] = None,
+ height: int = 64,
+ width: int = 64,
+ num_channels_latents: int = 4,
+ batch_size: int = 1,
+ timestep: Optional[int] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
  decode_chunk_size: int = 16,
- ):
- if latents is None:
- num_frames = video.shape[1]
- else:
- num_frames = latents.shape[2]
-
+ add_noise: bool = False,
+ ) -> torch.Tensor:
+ num_frames = video.shape[1] if latents is None else latents.shape[2]
  shape = (
  batch_size,
  num_channels_latents,
@@ -666,12 +662,6 @@ class AnimateDiffVideoToVideoPipeline(
  self.vae.to(dtype=torch.float32)

  if isinstance(generator, list):
- if len(generator) != batch_size:
- raise ValueError(
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
- )
-
  init_latents = [
  self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
  for i in range(batch_size)
@@ -708,8 +698,13 @@ class AnimateDiffVideoToVideoPipeline(
  if shape != latents.shape:
  # [B, C, F, H, W]
  raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
+
  latents = latents.to(device, dtype=dtype)

+ if add_noise:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.add_noise(latents, noise, timestep)
+
  return latents

  @property
@@ -735,6 +730,10 @@ class AnimateDiffVideoToVideoPipeline(
  def num_timesteps(self):
  return self._num_timesteps

+ @property
+ def interrupt(self):
+ return self._interrupt
+
  @torch.no_grad()
  def __call__(
  self,
@@ -743,6 +742,7 @@ class AnimateDiffVideoToVideoPipeline(
  height: Optional[int] = None,
  width: Optional[int] = None,
  num_inference_steps: int = 50,
+ enforce_inference_steps: bool = False,
  timesteps: Optional[List[int]] = None,
  sigmas: Optional[List[float]] = None,
  guidance_scale: float = 7.5,
@@ -874,9 +874,10 @@ class AnimateDiffVideoToVideoPipeline(
  self._guidance_scale = guidance_scale
  self._clip_skip = clip_skip
  self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False

  # 2. Define call parameters
- if prompt is not None and isinstance(prompt, str):
+ if prompt is not None and isinstance(prompt, (str, dict)):
  batch_size = 1
  elif prompt is not None and isinstance(prompt, list):
  batch_size = len(prompt)
@@ -884,51 +885,29 @@ class AnimateDiffVideoToVideoPipeline(
  batch_size = prompt_embeds.shape[0]

  device = self._execution_device
+ dtype = self.dtype

- # 3. Encode input prompt
- text_encoder_lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
- )
- prompt_embeds, negative_prompt_embeds = self.encode_prompt(
- prompt,
- device,
- num_videos_per_prompt,
- self.do_classifier_free_guidance,
- negative_prompt,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- lora_scale=text_encoder_lora_scale,
- clip_skip=self.clip_skip,
- )
-
- # For classifier free guidance, we need to do two forward passes.
- # Here we concatenate the unconditional and text embeddings into a single batch
- # to avoid doing two forward passes
- if self.do_classifier_free_guidance:
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-
- if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
- image_embeds = self.prepare_ip_adapter_image_embeds(
- ip_adapter_image,
- ip_adapter_image_embeds,
- device,
- batch_size * num_videos_per_prompt,
- self.do_classifier_free_guidance,
+ # 3. Prepare timesteps
+ if not enforce_inference_steps:
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
  )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
+ else:
+ denoising_inference_steps = int(num_inference_steps / strength)
+ timesteps, denoising_inference_steps = retrieve_timesteps(
+ self.scheduler, denoising_inference_steps, device, timesteps, sigmas
+ )
+ timesteps = timesteps[-num_inference_steps:]
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)

- # 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
- )
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
- latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
-
- # 5. Prepare latent variables
+ # 4. Prepare latent variables
  if latents is None:
  video = self.video_processor.preprocess_video(video, height=height, width=width)
  # Move the number of frames before the number of channels.
  video = video.permute(0, 2, 1, 3, 4)
- video = video.to(device=device, dtype=prompt_embeds.dtype)
+ video = video.to(device=device, dtype=dtype)
  num_channels_latents = self.unet.config.in_channels
  latents = self.prepare_latents(
  video=video,
@@ -937,17 +916,67 @@ class AnimateDiffVideoToVideoPipeline(
  num_channels_latents=num_channels_latents,
  batch_size=batch_size * num_videos_per_prompt,
  timestep=latent_timestep,
- dtype=prompt_embeds.dtype,
+ dtype=dtype,
  device=device,
  generator=generator,
  latents=latents,
  decode_chunk_size=decode_chunk_size,
+ add_noise=enforce_inference_steps,
  )

- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ # 5. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ num_frames = latents.shape[2]
+ if self.free_noise_enabled:
+ prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
+ prompt=prompt,
+ num_frames=num_frames,
+ device=device,
+ num_videos_per_prompt=num_videos_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+ else:
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
+
+ # 6. Prepare IP-Adapter embeddings
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

- # 7. Add image embeds for IP-Adapter
+ # 8. Add image embeds for IP-Adapter
  added_cond_kwargs = (
  {"image_embeds": image_embeds}
  if ip_adapter_image is not None or ip_adapter_image_embeds is not None
@@ -967,9 +996,12 @@ class AnimateDiffVideoToVideoPipeline(
  self._num_timesteps = len(timesteps)
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

- # 8. Denoising loop
+ # 9. Denoising loop
  with self.progress_bar(total=self._num_timesteps) as progress_bar:
  for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
  # expand the latents if we are doing classifier free guidance
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1005,14 +1037,14 @@ class AnimateDiffVideoToVideoPipeline(
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
  progress_bar.update()

- # 9. Post-processing
+ # 10. Post-processing
  if output_type == "latent":
  video = latents
  else:
  video_tensor = self.decode_latents(latents, decode_chunk_size)
  video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

- # 10. Offload all models
+ # 11. Offload all models
  self.maybe_free_model_hooks()

  if not return_dict:
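The `(str, dict)` prompt checks and the `_encode_prompt_free_noise` branch that recur across these AnimateDiff pipelines correspond to FreeNoise multi-prompt support: with FreeNoise enabled, prompt may be a dict mapping frame indices to prompts. A minimal sketch under that assumption; the checkpoint setup, frame indices, and prompts are illustrative:

```python
# Sketch: per-frame prompts with FreeNoise on an AnimateDiff pipeline (ids are illustrative).
# pipe = AnimateDiffPipeline.from_pretrained("...", motion_adapter=adapter, torch_dtype=torch.float16)
pipe.enable_free_noise(context_length=16, context_stride=4)

output = pipe(
    prompt={
        0: "a panda playing a guitar, forest background",
        32: "a panda playing a guitar, snowy background",
    },
    negative_prompt="low quality, worst quality",
    num_frames=64,
    num_inference_steps=25,
)
frames = output.frames[0]
```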