diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. diffusers/__init__.py +38 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +238 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +40 -7
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +6 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
  53. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  54. diffusers/pipelines/cogview3/__init__.py +47 -0
  55. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  56. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  57. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  58. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  60. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  62. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  63. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  64. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  66. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  67. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  68. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  70. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  71. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  72. diffusers/pipelines/flux/__init__.py +10 -0
  73. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  74. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  76. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  77. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  78. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  79. diffusers/pipelines/free_noise_utils.py +365 -5
  80. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  81. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  82. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  83. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  84. diffusers/pipelines/kolors/tokenizer.py +4 -0
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  86. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  87. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  89. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  90. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  91. diffusers/pipelines/pag/__init__.py +6 -0
  92. diffusers/pipelines/pag/pag_utils.py +8 -2
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  96. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  97. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  98. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  100. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  101. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  102. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  103. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  106. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  107. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  108. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  109. diffusers/pipelines/pipeline_utils.py +123 -180
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  111. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  113. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  117. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  120. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  121. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  122. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  123. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  129. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  131. diffusers/quantizers/__init__.py +16 -0
  132. diffusers/quantizers/auto.py +126 -0
  133. diffusers/quantizers/base.py +233 -0
  134. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  135. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  136. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  137. diffusers/quantizers/quantization_config.py +391 -0
  138. diffusers/schedulers/scheduling_ddim.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  140. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm.py +4 -1
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  143. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  148. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  149. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  150. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  151. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  152. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  155. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  156. diffusers/schedulers/scheduling_sasolver.py +78 -1
  157. diffusers/schedulers/scheduling_unclip.py +4 -1
  158. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  159. diffusers/training_utils.py +48 -18
  160. diffusers/utils/__init__.py +2 -1
  161. diffusers/utils/dummy_pt_objects.py +60 -0
  162. diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
  163. diffusers/utils/hub_utils.py +16 -4
  164. diffusers/utils/import_utils.py +31 -8
  165. diffusers/utils/loading_utils.py +28 -4
  166. diffusers/utils/peft_utils.py +3 -3
  167. diffusers/utils/testing_utils.py +59 -0
  168. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  169. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
  170. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  172. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  173. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
@@ -123,8 +123,16 @@ else:
123
123
  "AnimateDiffSDXLPipeline",
124
124
  "AnimateDiffSparseControlNetPipeline",
125
125
  "AnimateDiffVideoToVideoPipeline",
126
+ "AnimateDiffVideoToVideoControlNetPipeline",
127
+ ]
128
+ _import_structure["flux"] = [
129
+ "FluxControlNetPipeline",
130
+ "FluxControlNetImg2ImgPipeline",
131
+ "FluxControlNetInpaintPipeline",
132
+ "FluxImg2ImgPipeline",
133
+ "FluxInpaintPipeline",
134
+ "FluxPipeline",
126
135
  ]
127
- _import_structure["flux"] = ["FluxPipeline"]
128
136
  _import_structure["audioldm"] = ["AudioLDMPipeline"]
129
137
  _import_structure["audioldm2"] = [
130
138
  "AudioLDM2Pipeline",
@@ -132,7 +140,13 @@ else:
132
140
  "AudioLDM2UNet2DConditionModel",
133
141
  ]
134
142
  _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
135
- _import_structure["cogvideo"] = ["CogVideoXPipeline"]
143
+ _import_structure["cogvideo"] = [
144
+ "CogVideoXPipeline",
145
+ "CogVideoXImageToVideoPipeline",
146
+ "CogVideoXVideoToVideoPipeline",
147
+ "CogVideoXFunControlPipeline",
148
+ ]
149
+ _import_structure["cogview3"] = ["CogView3PlusPipeline"]
136
150
  _import_structure["controlnet"].extend(
137
151
  [
138
152
  "BlipDiffusionControlNetPipeline",
@@ -146,14 +160,17 @@ else:
146
160
  )
147
161
  _import_structure["pag"].extend(
148
162
  [
163
+ "StableDiffusionControlNetPAGInpaintPipeline",
149
164
  "AnimateDiffPAGPipeline",
150
165
  "KolorsPAGPipeline",
151
166
  "HunyuanDiTPAGPipeline",
152
167
  "StableDiffusion3PAGPipeline",
153
168
  "StableDiffusionPAGPipeline",
169
+ "StableDiffusionPAGImg2ImgPipeline",
154
170
  "StableDiffusionControlNetPAGPipeline",
155
171
  "StableDiffusionXLPAGPipeline",
156
172
  "StableDiffusionXLPAGInpaintPipeline",
173
+ "StableDiffusionXLControlNetPAGImg2ImgPipeline",
157
174
  "StableDiffusionXLControlNetPAGPipeline",
158
175
  "StableDiffusionXLPAGImg2ImgPipeline",
159
176
  "PixArtSigmaPAGPipeline",
@@ -173,6 +190,7 @@ else:
173
190
  _import_structure["controlnet_sd3"].extend(
174
191
  [
175
192
  "StableDiffusion3ControlNetPipeline",
193
+ "StableDiffusion3ControlNetInpaintingPipeline",
176
194
  ]
177
195
  )
178
196
  _import_structure["deepfloyd_if"] = [
@@ -442,6 +460,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
442
460
  AnimateDiffPipeline,
443
461
  AnimateDiffSDXLPipeline,
444
462
  AnimateDiffSparseControlNetPipeline,
463
+ AnimateDiffVideoToVideoControlNetPipeline,
445
464
  AnimateDiffVideoToVideoPipeline,
446
465
  )
447
466
  from .audioldm import AudioLDMPipeline
@@ -452,7 +471,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
452
471
  )
453
472
  from .aura_flow import AuraFlowPipeline
454
473
  from .blip_diffusion import BlipDiffusionPipeline
455
- from .cogvideo import CogVideoXPipeline
474
+ from .cogvideo import (
475
+ CogVideoXFunControlPipeline,
476
+ CogVideoXImageToVideoPipeline,
477
+ CogVideoXPipeline,
478
+ CogVideoXVideoToVideoPipeline,
479
+ )
480
+ from .cogview3 import CogView3PlusPipeline
456
481
  from .controlnet import (
457
482
  BlipDiffusionControlNetPipeline,
458
483
  StableDiffusionControlNetImg2ImgPipeline,
@@ -465,9 +490,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
465
490
  from .controlnet_hunyuandit import (
466
491
  HunyuanDiTControlNetPipeline,
467
492
  )
468
- from .controlnet_sd3 import (
469
- StableDiffusion3ControlNetPipeline,
470
- )
493
+ from .controlnet_sd3 import StableDiffusion3ControlNetInpaintingPipeline, StableDiffusion3ControlNetPipeline
471
494
  from .controlnet_xs import (
472
495
  StableDiffusionControlNetXSPipeline,
473
496
  StableDiffusionXLControlNetXSPipeline,
@@ -494,7 +517,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
494
517
  VersatileDiffusionTextToImagePipeline,
495
518
  VQDiffusionPipeline,
496
519
  )
497
- from .flux import FluxPipeline
520
+ from .flux import (
521
+ FluxControlNetImg2ImgPipeline,
522
+ FluxControlNetInpaintPipeline,
523
+ FluxControlNetPipeline,
524
+ FluxImg2ImgPipeline,
525
+ FluxInpaintPipeline,
526
+ FluxPipeline,
527
+ )
498
528
  from .hunyuandit import HunyuanDiTPipeline
499
529
  from .i2vgen_xl import I2VGenXLPipeline
500
530
  from .kandinsky import (
@@ -546,8 +576,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
546
576
  KolorsPAGPipeline,
547
577
  PixArtSigmaPAGPipeline,
548
578
  StableDiffusion3PAGPipeline,
579
+ StableDiffusionControlNetPAGInpaintPipeline,
549
580
  StableDiffusionControlNetPAGPipeline,
581
+ StableDiffusionPAGImg2ImgPipeline,
550
582
  StableDiffusionPAGPipeline,
583
+ StableDiffusionXLControlNetPAGImg2ImgPipeline,
551
584
  StableDiffusionXLControlNetPAGPipeline,
552
585
  StableDiffusionXLPAGImg2ImgPipeline,
553
586
  StableDiffusionXLPAGInpaintPipeline,
@@ -26,6 +26,7 @@ else:
26
26
  _import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
27
27
  _import_structure["pipeline_animatediff_sparsectrl"] = ["AnimateDiffSparseControlNetPipeline"]
28
28
  _import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
29
+ _import_structure["pipeline_animatediff_video2video_controlnet"] = ["AnimateDiffVideoToVideoControlNetPipeline"]
29
30
 
30
31
  if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
31
32
  try:
@@ -40,6 +41,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
40
41
  from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
41
42
  from .pipeline_animatediff_sparsectrl import AnimateDiffSparseControlNetPipeline
42
43
  from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
44
+ from .pipeline_animatediff_video2video_controlnet import AnimateDiffVideoToVideoControlNetPipeline
43
45
  from .pipeline_output import AnimateDiffPipelineOutput
44
46
 
45
47
  else:
@@ -432,7 +432,6 @@ class AnimateDiffPipeline(
432
432
  extra_step_kwargs["generator"] = generator
433
433
  return extra_step_kwargs
434
434
 
435
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
436
435
  def check_inputs(
437
436
  self,
438
437
  prompt,
@@ -470,8 +469,8 @@ class AnimateDiffPipeline(
470
469
  raise ValueError(
471
470
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
472
471
  )
473
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
474
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
472
+ elif prompt is not None and not isinstance(prompt, (str, list, dict)):
473
+ raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}")
475
474
 
476
475
  if negative_prompt is not None and negative_prompt_embeds is not None:
477
476
  raise ValueError(
@@ -557,11 +556,15 @@ class AnimateDiffPipeline(
557
556
  def num_timesteps(self):
558
557
  return self._num_timesteps
559
558
 
559
+ @property
560
+ def interrupt(self):
561
+ return self._interrupt
562
+
560
563
  @torch.no_grad()
561
564
  @replace_example_docstring(EXAMPLE_DOC_STRING)
562
565
  def __call__(
563
566
  self,
564
- prompt: Union[str, List[str]] = None,
567
+ prompt: Optional[Union[str, List[str]]] = None,
565
568
  num_frames: Optional[int] = 16,
566
569
  height: Optional[int] = None,
567
570
  width: Optional[int] = None,
@@ -701,9 +704,10 @@ class AnimateDiffPipeline(
701
704
  self._guidance_scale = guidance_scale
702
705
  self._clip_skip = clip_skip
703
706
  self._cross_attention_kwargs = cross_attention_kwargs
707
+ self._interrupt = False
704
708
 
705
709
  # 2. Define call parameters
706
- if prompt is not None and isinstance(prompt, str):
710
+ if prompt is not None and isinstance(prompt, (str, dict)):
707
711
  batch_size = 1
708
712
  elif prompt is not None and isinstance(prompt, list):
709
713
  batch_size = len(prompt)
@@ -716,22 +720,39 @@ class AnimateDiffPipeline(
716
720
  text_encoder_lora_scale = (
717
721
  self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
718
722
  )
719
- prompt_embeds, negative_prompt_embeds = self.encode_prompt(
720
- prompt,
721
- device,
722
- num_videos_per_prompt,
723
- self.do_classifier_free_guidance,
724
- negative_prompt,
725
- prompt_embeds=prompt_embeds,
726
- negative_prompt_embeds=negative_prompt_embeds,
727
- lora_scale=text_encoder_lora_scale,
728
- clip_skip=self.clip_skip,
729
- )
730
- # For classifier free guidance, we need to do two forward passes.
731
- # Here we concatenate the unconditional and text embeddings into a single batch
732
- # to avoid doing two forward passes
733
- if self.do_classifier_free_guidance:
734
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
723
+ if self.free_noise_enabled:
724
+ prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
725
+ prompt=prompt,
726
+ num_frames=num_frames,
727
+ device=device,
728
+ num_videos_per_prompt=num_videos_per_prompt,
729
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
730
+ negative_prompt=negative_prompt,
731
+ prompt_embeds=prompt_embeds,
732
+ negative_prompt_embeds=negative_prompt_embeds,
733
+ lora_scale=text_encoder_lora_scale,
734
+ clip_skip=self.clip_skip,
735
+ )
736
+ else:
737
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
738
+ prompt,
739
+ device,
740
+ num_videos_per_prompt,
741
+ self.do_classifier_free_guidance,
742
+ negative_prompt,
743
+ prompt_embeds=prompt_embeds,
744
+ negative_prompt_embeds=negative_prompt_embeds,
745
+ lora_scale=text_encoder_lora_scale,
746
+ clip_skip=self.clip_skip,
747
+ )
748
+
749
+ # For classifier free guidance, we need to do two forward passes.
750
+ # Here we concatenate the unconditional and text embeddings into a single batch
751
+ # to avoid doing two forward passes
752
+ if self.do_classifier_free_guidance:
753
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
754
+
755
+ prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
735
756
 
736
757
  if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
737
758
  image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -783,6 +804,9 @@ class AnimateDiffPipeline(
783
804
  # 8. Denoising loop
784
805
  with self.progress_bar(total=self._num_timesteps) as progress_bar:
785
806
  for i, t in enumerate(timesteps):
807
+ if self.interrupt:
808
+ continue
809
+
786
810
  # expand the latents if we are doing classifier free guidance
787
811
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
788
812
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -505,8 +505,8 @@ class AnimateDiffControlNetPipeline(
505
505
  raise ValueError(
506
506
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
507
507
  )
508
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
509
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
508
+ elif prompt is not None and not isinstance(prompt, (str, list, dict)):
509
+ raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
510
510
 
511
511
  if negative_prompt is not None and negative_prompt_embeds is not None:
512
512
  raise ValueError(
@@ -699,6 +699,10 @@ class AnimateDiffControlNetPipeline(
699
699
  def num_timesteps(self):
700
700
  return self._num_timesteps
701
701
 
702
+ @property
703
+ def interrupt(self):
704
+ return self._interrupt
705
+
702
706
  @torch.no_grad()
703
707
  def __call__(
704
708
  self,
@@ -858,9 +862,10 @@ class AnimateDiffControlNetPipeline(
858
862
  self._guidance_scale = guidance_scale
859
863
  self._clip_skip = clip_skip
860
864
  self._cross_attention_kwargs = cross_attention_kwargs
865
+ self._interrupt = False
861
866
 
862
867
  # 2. Define call parameters
863
- if prompt is not None and isinstance(prompt, str):
868
+ if prompt is not None and isinstance(prompt, (str, dict)):
864
869
  batch_size = 1
865
870
  elif prompt is not None and isinstance(prompt, list):
866
871
  batch_size = len(prompt)
@@ -883,22 +888,39 @@ class AnimateDiffControlNetPipeline(
883
888
  text_encoder_lora_scale = (
884
889
  cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
885
890
  )
886
- prompt_embeds, negative_prompt_embeds = self.encode_prompt(
887
- prompt,
888
- device,
889
- num_videos_per_prompt,
890
- self.do_classifier_free_guidance,
891
- negative_prompt,
892
- prompt_embeds=prompt_embeds,
893
- negative_prompt_embeds=negative_prompt_embeds,
894
- lora_scale=text_encoder_lora_scale,
895
- clip_skip=self.clip_skip,
896
- )
897
- # For classifier free guidance, we need to do two forward passes.
898
- # Here we concatenate the unconditional and text embeddings into a single batch
899
- # to avoid doing two forward passes
900
- if self.do_classifier_free_guidance:
901
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
891
+ if self.free_noise_enabled:
892
+ prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
893
+ prompt=prompt,
894
+ num_frames=num_frames,
895
+ device=device,
896
+ num_videos_per_prompt=num_videos_per_prompt,
897
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
898
+ negative_prompt=negative_prompt,
899
+ prompt_embeds=prompt_embeds,
900
+ negative_prompt_embeds=negative_prompt_embeds,
901
+ lora_scale=text_encoder_lora_scale,
902
+ clip_skip=self.clip_skip,
903
+ )
904
+ else:
905
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
906
+ prompt,
907
+ device,
908
+ num_videos_per_prompt,
909
+ self.do_classifier_free_guidance,
910
+ negative_prompt,
911
+ prompt_embeds=prompt_embeds,
912
+ negative_prompt_embeds=negative_prompt_embeds,
913
+ lora_scale=text_encoder_lora_scale,
914
+ clip_skip=self.clip_skip,
915
+ )
916
+
917
+ # For classifier free guidance, we need to do two forward passes.
918
+ # Here we concatenate the unconditional and text embeddings into a single batch
919
+ # to avoid doing two forward passes
920
+ if self.do_classifier_free_guidance:
921
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
922
+
923
+ prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
902
924
 
903
925
  if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
904
926
  image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -990,6 +1012,9 @@ class AnimateDiffControlNetPipeline(
990
1012
  # 8. Denoising loop
991
1013
  with self.progress_bar(total=self._num_timesteps) as progress_bar:
992
1014
  for i, t in enumerate(timesteps):
1015
+ if self.interrupt:
1016
+ continue
1017
+
993
1018
  # expand the latents if we are doing classifier free guidance
994
1019
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
995
1020
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1002,7 +1027,6 @@ class AnimateDiffControlNetPipeline(
1002
1027
  else:
1003
1028
  control_model_input = latent_model_input
1004
1029
  controlnet_prompt_embeds = prompt_embeds
1005
- controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)
1006
1030
 
1007
1031
  if isinstance(controlnet_keep[i], list):
1008
1032
  cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
@@ -113,9 +113,21 @@ EXAMPLE_DOC_STRING = """
113
113
 
114
114
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
115
115
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
116
- """
117
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
118
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
116
+ r"""
117
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
118
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
119
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
120
+
121
+ Args:
122
+ noise_cfg (`torch.Tensor`):
123
+ The predicted noise tensor for the guided diffusion process.
124
+ noise_pred_text (`torch.Tensor`):
125
+ The predicted noise tensor for the text-guided diffusion process.
126
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
127
+ A rescale factor applied to the noise predictions.
128
+
129
+ Returns:
130
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
119
131
  """
120
132
  std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
121
133
  std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -135,7 +147,7 @@ def retrieve_timesteps(
135
147
  sigmas: Optional[List[float]] = None,
136
148
  **kwargs,
137
149
  ):
138
- """
150
+ r"""
139
151
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
140
152
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
141
153
 
@@ -1143,6 +1155,8 @@ class AnimateDiffSDXLPipeline(
1143
1155
  add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1144
1156
  add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1145
1157
 
1158
+ prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
1159
+
1146
1160
  prompt_embeds = prompt_embeds.to(device)
1147
1161
  add_text_embeds = add_text_embeds.to(device)
1148
1162
  add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1)
@@ -878,6 +878,8 @@ class AnimateDiffSparseControlNetPipeline(
878
878
  if self.do_classifier_free_guidance:
879
879
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
880
880
 
881
+ prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
882
+
881
883
  # 4. Prepare IP-Adapter embeddings
882
884
  if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
883
885
  image_embeds = self.prepare_ip_adapter_image_embeds(