diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2222 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +1 -12
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +262 -2
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1795 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +319 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +1 -4
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +19 -16
  210. diffusers/utils/loading_utils.py +76 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
diffusers/models/unets/unet_spatio_temporal_condition.py CHANGED
@@ -261,7 +261,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
  processors: Dict[str, AttentionProcessor],
  ):
  if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+ processors[f"{name}.processor"] = module.get_processor()

  for sub_name, child in module.named_children():
  fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
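This hunk (and the matching one in UVit2DModel below) drops the deprecated `return_deprecated_lora=True` argument when collecting attention processors. A minimal, hypothetical sketch of the recursive collection pattern these hunks touch, with toy modules standing in for real diffusers attention blocks:

```python
import torch.nn as nn

class ToyProcessor:
    """Stand-in for a diffusers attention processor object."""

class ToyAttention(nn.Module):
    def get_processor(self):
        # In 0.30.0 this is called without `return_deprecated_lora`.
        return ToyProcessor()

def collect_processors(name: str, module: nn.Module, processors: dict) -> dict:
    # Mirror of fn_recursive_add_processors: record any submodule exposing get_processor().
    if hasattr(module, "get_processor"):
        processors[f"{name}.processor"] = module.get_processor()
    for sub_name, child in module.named_children():
        collect_processors(f"{name}.{sub_name}", child, processors)
    return processors

model = nn.Sequential(ToyAttention(), nn.Linear(4, 4))
print(collect_processors("model", model, {}))  # {'model.0.processor': <ToyProcessor object ...>}
```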
diffusers/models/unets/unet_stable_cascade.py CHANGED
@@ -478,9 +478,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
  create_custom_forward(block), x, r_embed, use_reentrant=False
  )
  else:
- x = x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block), use_reentrant=False
- )
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), use_reentrant=False)
  if i < len(repmap):
  x = repmap[i](x)
  level_outputs.insert(0, x)
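The fix above removes a duplicated `x = x =` assignment in StableCascadeUNet's gradient-checkpointing branch. For context, a minimal, self-contained sketch of the closure-plus-`torch.utils.checkpoint` pattern used there (toy block, illustrative only; the real call wraps StableCascade blocks):

```python
import torch

def create_custom_forward(module):
    # Wrap the module so checkpoint() can re-run its forward pass during backward.
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

block = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
# use_reentrant=False selects the non-reentrant checkpointing variant, as in the diff.
out = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
out.sum().backward()
```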
diffusers/models/unets/uvit_2d.py CHANGED
@@ -225,7 +225,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):

  def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
  if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+ processors[f"{name}.processor"] = module.get_processor()

  for sub_name, child in module.named_children():
  fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
diffusers/models/upsampling.py CHANGED
@@ -348,6 +348,70 @@ class KUpsample2D(nn.Module):
  return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)


+ class CogVideoXUpsample3D(nn.Module):
+ r"""
+ A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
+
+ Args:
+ in_channels (`int`):
+ Number of channels in the input image.
+ out_channels (`int`):
+ Number of channels produced by the convolution.
+ kernel_size (`int`, defaults to `3`):
+ Size of the convolving kernel.
+ stride (`int`, defaults to `1`):
+ Stride of the convolution.
+ padding (`int`, defaults to `1`):
+ Padding added to all four sides of the input.
+ compress_time (`bool`, defaults to `False`):
+ Whether or not to compress the time dimension.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 3,
+ stride: int = 1,
+ padding: int = 1,
+ compress_time: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+ self.compress_time = compress_time
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ if self.compress_time:
+ if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
+ # split first frame
+ x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
+
+ x_first = F.interpolate(x_first, scale_factor=2.0)
+ x_rest = F.interpolate(x_rest, scale_factor=2.0)
+ x_first = x_first[:, :, None, :, :]
+ inputs = torch.cat([x_first, x_rest], dim=2)
+ elif inputs.shape[2] > 1:
+ inputs = F.interpolate(inputs, scale_factor=2.0)
+ else:
+ inputs = inputs.squeeze(2)
+ inputs = F.interpolate(inputs, scale_factor=2.0)
+ inputs = inputs[:, :, None, :, :]
+ else:
+ # only interpolate 2D
+ b, c, t, h, w = inputs.shape
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ inputs = F.interpolate(inputs, scale_factor=2.0)
+ inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
+
+ b, c, t, h, w = inputs.shape
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ inputs = self.conv(inputs)
+ inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
+
+ return inputs
+
+
  def upfirdn2d_native(
  tensor: torch.Tensor,
  kernel: torch.Tensor,
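A hedged usage sketch for the upsampler added above. The import path follows this diff (`diffusers/models/upsampling.py`) and the `(batch, channels, frames, height, width)` layout follows the `forward` code, but treat both as assumptions rather than documented API:

```python
import torch
from diffusers.models.upsampling import CogVideoXUpsample3D  # location per this diff

up = CogVideoXUpsample3D(in_channels=8, out_channels=8, compress_time=False)
video = torch.randn(1, 8, 5, 16, 16)  # (batch, channels, frames, height, width)
out = up(video)
print(out.shape)  # spatial dims doubled, frame count unchanged: (1, 8, 5, 32, 32)
```

With `compress_time=True` the time dimension is also expanded (the first frame of an odd-length clip is handled separately, per the branch above).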
diffusers/models/vq_model.py CHANGED
@@ -16,10 +16,14 @@ from .autoencoders.vq_model import VQEncoderOutput, VQModel


  class VQEncoderOutput(VQEncoderOutput):
- deprecation_message = "Importing `VQEncoderOutput` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQEncoderOutput`, instead."
- deprecate("VQEncoderOutput", "0.31", deprecation_message)
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `VQEncoderOutput` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQEncoderOutput`, instead."
+ deprecate("VQEncoderOutput", "0.31", deprecation_message)
+ super().__init__(*args, **kwargs)


  class VQModel(VQModel):
- deprecation_message = "Importing `VQModel` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQModel`, instead."
- deprecate("VQModel", "0.31", deprecation_message)
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `VQModel` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQModel`, instead."
+ deprecate("VQModel", "0.31", deprecation_message)
+ super().__init__(*args, **kwargs)
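The change above moves the deprecation warning from class-body code (which fires once at import time) into `__init__` (which fires only when the legacy class is actually instantiated). A generic, hypothetical sketch of that shim pattern outside diffusers:

```python
import warnings

class NewThing:  # the relocated implementation
    def __init__(self, value: int = 0):
        self.value = value

class OldThing(NewThing):  # legacy import path kept as a thin subclass
    def __init__(self, *args, **kwargs):
        # Warning now fires on instantiation, not merely on importing the old module.
        warnings.warn(
            "Importing `OldThing` from its legacy module is deprecated; use `NewThing` instead.",
            FutureWarning,
        )
        super().__init__(*args, **kwargs)

OldThing(3)  # emits the FutureWarning; a plain import stays silent
```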
diffusers/optimization.py CHANGED
@@ -87,7 +87,7 @@ def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_
  The optimizer for which to schedule the learning rate.
  step_rules (`string`):
  The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate
- if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30
+ if multiple 1 for the first 10 steps, multiple 0.1 for the next 20 steps, multiple 0.01 for the next 30
  steps and multiple 0.005 for the other steps.
  last_epoch (`int`, *optional*, defaults to -1):
  The index of the last epoch when resuming training.
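Besides the typo fix, this docstring describes the `step_rules` format for `get_piecewise_constant_schedule`. A short usage sketch of that helper (optimizer and rule string are illustrative):

```python
import torch
from diffusers.optimization import get_piecewise_constant_schedule

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Factor 1.0 for the first 10 steps, 0.1 for the next 20, 0.01 for the next 30, then 0.005.
scheduler = get_piecewise_constant_schedule(optimizer, step_rules="1:10,0.1:20,0.01:30,0.005")
for _ in range(5):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())
```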
diffusers/pipelines/__init__.py CHANGED
@@ -10,6 +10,7 @@ from ..utils import (
  is_librosa_available,
  is_note_seq_available,
  is_onnx_available,
+ is_sentencepiece_available,
  is_torch_available,
  is_torch_npu_available,
  is_transformers_available,
@@ -20,12 +21,14 @@ from ..utils import (
  _dummy_objects = {}
  _import_structure = {
  "controlnet": [],
+ "controlnet_hunyuandit": [],
  "controlnet_sd3": [],
  "controlnet_xs": [],
  "deprecated": [],
  "latent_diffusion": [],
  "ledits_pp": [],
  "marigold": [],
+ "pag": [],
  "stable_diffusion": [],
  "stable_diffusion_xl": [],
  }
@@ -116,9 +119,12 @@ else:
  _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
  _import_structure["animatediff"] = [
  "AnimateDiffPipeline",
+ "AnimateDiffControlNetPipeline",
  "AnimateDiffSDXLPipeline",
+ "AnimateDiffSparseControlNetPipeline",
  "AnimateDiffVideoToVideoPipeline",
  ]
+ _import_structure["flux"] = ["FluxPipeline"]
  _import_structure["audioldm"] = ["AudioLDMPipeline"]
  _import_structure["audioldm2"] = [
  "AudioLDM2Pipeline",
@@ -126,6 +132,7 @@ else:
  "AudioLDM2UNet2DConditionModel",
  ]
  _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
+ _import_structure["cogvideo"] = ["CogVideoXPipeline"]
  _import_structure["controlnet"].extend(
  [
  "BlipDiffusionControlNetPipeline",
@@ -137,12 +144,32 @@ else:
  "StableDiffusionXLControlNetPipeline",
  ]
  )
+ _import_structure["pag"].extend(
+ [
+ "AnimateDiffPAGPipeline",
+ "KolorsPAGPipeline",
+ "HunyuanDiTPAGPipeline",
+ "StableDiffusion3PAGPipeline",
+ "StableDiffusionPAGPipeline",
+ "StableDiffusionControlNetPAGPipeline",
+ "StableDiffusionXLPAGPipeline",
+ "StableDiffusionXLPAGInpaintPipeline",
+ "StableDiffusionXLControlNetPAGPipeline",
+ "StableDiffusionXLPAGImg2ImgPipeline",
+ "PixArtSigmaPAGPipeline",
+ ]
+ )
  _import_structure["controlnet_xs"].extend(
  [
  "StableDiffusionControlNetXSPipeline",
  "StableDiffusionXLControlNetXSPipeline",
  ]
  )
+ _import_structure["controlnet_hunyuandit"].extend(
+ [
+ "HunyuanDiTControlNetPipeline",
+ ]
+ )
  _import_structure["controlnet_sd3"].extend(
  [
  "StableDiffusion3ControlNetPipeline",
@@ -193,6 +220,8 @@ else:
  "LEditsPPPipelineStableDiffusionXL",
  ]
  )
+ _import_structure["latte"] = ["LattePipeline"]
+ _import_structure["lumina"] = ["LuminaText2ImgPipeline"]
  _import_structure["marigold"].extend(
  [
  "MarigoldDepthPipeline",
@@ -205,6 +234,10 @@ else:
  _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
  _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
  _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
+ _import_structure["stable_audio"] = [
+ "StableAudioProjectionModel",
+ "StableAudioPipeline",
+ ]
  _import_structure["stable_cascade"] = [
  "StableCascadeCombinedPipeline",
  "StableCascadeDecoderPipeline",
@@ -226,7 +259,12 @@ else:
  "StableDiffusionLDM3DPipeline",
  ]
  )
- _import_structure["stable_diffusion_3"] = ["StableDiffusion3Pipeline", "StableDiffusion3Img2ImgPipeline"]
+ _import_structure["aura_flow"] = ["AuraFlowPipeline"]
+ _import_structure["stable_diffusion_3"] = [
+ "StableDiffusion3Pipeline",
+ "StableDiffusion3Img2ImgPipeline",
+ "StableDiffusion3InpaintPipeline",
+ ]
  _import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
  _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
  _import_structure["stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
@@ -310,6 +348,22 @@ else:
  "StableDiffusionKDiffusionPipeline",
  "StableDiffusionXLKDiffusionPipeline",
  ]
+
+ try:
+ if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ..utils import (
+ dummy_torch_and_transformers_and_sentencepiece_objects,
+ )
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects))
+ else:
+ _import_structure["kolors"] = [
+ "KolorsPipeline",
+ "KolorsImg2ImgPipeline",
+ ]
+
  try:
  if not is_flax_available():
  raise OptionalDependencyNotAvailable()
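The new Kolors entries above are registered only when torch, transformers, and sentencepiece are all importable; otherwise dummy placeholders are exposed. A simplified, hypothetical sketch of that guard pattern (names are illustrative, not the diffusers internals):

```python
import importlib.util

def is_sentencepiece_available() -> bool:
    return importlib.util.find_spec("sentencepiece") is not None

if is_sentencepiece_available():
    pipelines = {"kolors": ["KolorsPipeline", "KolorsImg2ImgPipeline"]}
else:
    class KolorsPipeline:  # dummy stand-in that fails loudly only when used
        def __init__(self, *args, **kwargs):
            raise ImportError("KolorsPipeline requires the `sentencepiece` package.")
    pipelines = {}
```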
@@ -383,14 +437,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
  from ..utils.dummy_torch_and_transformers_objects import *
  else:
  from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
- from .animatediff import AnimateDiffPipeline, AnimateDiffSDXLPipeline, AnimateDiffVideoToVideoPipeline
+ from .animatediff import (
+ AnimateDiffControlNetPipeline,
+ AnimateDiffPipeline,
+ AnimateDiffSDXLPipeline,
+ AnimateDiffSparseControlNetPipeline,
+ AnimateDiffVideoToVideoPipeline,
+ )
  from .audioldm import AudioLDMPipeline
  from .audioldm2 import (
  AudioLDM2Pipeline,
  AudioLDM2ProjectionModel,
  AudioLDM2UNet2DConditionModel,
  )
+ from .aura_flow import AuraFlowPipeline
  from .blip_diffusion import BlipDiffusionPipeline
+ from .cogvideo import CogVideoXPipeline
  from .controlnet import (
  BlipDiffusionControlNetPipeline,
  StableDiffusionControlNetImg2ImgPipeline,
@@ -400,6 +462,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
  StableDiffusionXLControlNetInpaintPipeline,
  StableDiffusionXLControlNetPipeline,
  )
+ from .controlnet_hunyuandit import (
+ HunyuanDiTControlNetPipeline,
+ )
  from .controlnet_sd3 import (
  StableDiffusion3ControlNetPipeline,
  )
@@ -429,6 +494,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
  VersatileDiffusionTextToImagePipeline,
  VQDiffusionPipeline,
  )
+ from .flux import FluxPipeline
  from .hunyuandit import HunyuanDiTPipeline
  from .i2vgen_xl import I2VGenXLPipeline
  from .kandinsky import (
@@ -461,22 +527,38 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
  LatentConsistencyModelPipeline,
  )
  from .latent_diffusion import LDMTextToImagePipeline
+ from .latte import LattePipeline
  from .ledits_pp import (
  LEditsPPDiffusionPipelineOutput,
  LEditsPPInversionPipelineOutput,
  LEditsPPPipelineStableDiffusion,
  LEditsPPPipelineStableDiffusionXL,
  )
+ from .lumina import LuminaText2ImgPipeline
  from .marigold import (
  MarigoldDepthPipeline,
  MarigoldNormalsPipeline,
  )
  from .musicldm import MusicLDMPipeline
+ from .pag import (
+ AnimateDiffPAGPipeline,
+ HunyuanDiTPAGPipeline,
+ KolorsPAGPipeline,
+ PixArtSigmaPAGPipeline,
+ StableDiffusion3PAGPipeline,
+ StableDiffusionControlNetPAGPipeline,
+ StableDiffusionPAGPipeline,
+ StableDiffusionXLControlNetPAGPipeline,
+ StableDiffusionXLPAGImg2ImgPipeline,
+ StableDiffusionXLPAGInpaintPipeline,
+ StableDiffusionXLPAGPipeline,
+ )
  from .paint_by_example import PaintByExamplePipeline
  from .pia import PIAPipeline
  from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
  from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
  from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
+ from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
  from .stable_cascade import (
  StableCascadeCombinedPipeline,
  StableCascadeDecoderPipeline,
@@ -495,7 +577,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
  StableUnCLIPImg2ImgPipeline,
  StableUnCLIPPipeline,
  )
- from .stable_diffusion_3 import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
+ from .stable_diffusion_3 import (
+ StableDiffusion3Img2ImgPipeline,
+ StableDiffusion3InpaintPipeline,
+ StableDiffusion3Pipeline,
+ )
  from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
  from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
  from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline
@@ -567,6 +653,17 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
  StableDiffusionXLKDiffusionPipeline,
  )

+ try:
+ if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ..utils.dummy_torch_and_transformers_and_sentencepiece_objects import *
+ else:
+ from .kolors import (
+ KolorsImg2ImgPipeline,
+ KolorsPipeline,
+ )
+
  try:
  if not is_flax_available():
  raise OptionalDependencyNotAvailable()
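Taken together, the lazy-import additions above make this release's new pipelines importable from the top-level package. A quick smoke test (imports only, no weights downloaded; names are taken from the `_import_structure` changes shown in this diff):

```python
from diffusers import (
    AuraFlowPipeline,
    CogVideoXPipeline,
    FluxPipeline,
    LattePipeline,
    LuminaText2ImgPipeline,
    StableAudioPipeline,
    StableDiffusion3InpaintPipeline,
)
```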
diffusers/pipelines/animatediff/__init__.py CHANGED
@@ -22,7 +22,9 @@ except OptionalDependencyNotAvailable:
  _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
  else:
  _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
+ _import_structure["pipeline_animatediff_controlnet"] = ["AnimateDiffControlNetPipeline"]
  _import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
+ _import_structure["pipeline_animatediff_sparsectrl"] = ["AnimateDiffSparseControlNetPipeline"]
  _import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]

  if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -34,7 +36,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:

  else:
  from .pipeline_animatediff import AnimateDiffPipeline
+ from .pipeline_animatediff_controlnet import AnimateDiffControlNetPipeline
  from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
+ from .pipeline_animatediff_sparsectrl import AnimateDiffSparseControlNetPipeline
  from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
  from .pipeline_output import AnimateDiffPipelineOutput

diffusers/pipelines/animatediff/pipeline_animatediff.py CHANGED
@@ -19,7 +19,7 @@ import torch
  from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

  from ...image_processor import PipelineImageInput
- from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+ from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
  from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
  from ...models.lora import adjust_lora_scale_text_encoder
  from ...models.unets.unet_motion_model import MotionAdapter
@@ -42,6 +42,7 @@ from ...utils import (
  from ...utils.torch_utils import randn_tensor
  from ...video_processor import VideoProcessor
  from ..free_init_utils import FreeInitMixin
+ from ..free_noise_utils import AnimateDiffFreeNoiseMixin
  from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
  from .pipeline_output import AnimateDiffPipelineOutput

@@ -70,8 +71,9 @@ class AnimateDiffPipeline(
  StableDiffusionMixin,
  TextualInversionLoaderMixin,
  IPAdapterMixin,
- LoraLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
  FreeInitMixin,
+ AnimateDiffFreeNoiseMixin,
  ):
  r"""
  Pipeline for text-to-video generation.
@@ -81,8 +83,8 @@ class AnimateDiffPipeline(

  The pipeline also inherits the following loading methods:
  - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
  - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

  Args:
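The rename from `LoraLoaderMixin` to `StableDiffusionLoraLoaderMixin` does not change the pipeline-level LoRA API this docstring points at. A hedged usage sketch (the checkpoint IDs are illustrative examples, not part of this diff):

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
)
# load_lora_weights is now provided via StableDiffusionLoraLoaderMixin.
pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-in", adapter_name="zoom-in")
```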
@@ -184,7 +186,7 @@ class AnimateDiffPipeline(
  """
  # set lora scale so that monkey patched LoRA
  # function of text encoder can correctly access it
- if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
  self._lora_scale = lora_scale

  # dynamically adjust the LoRA scale
@@ -317,7 +319,7 @@ class AnimateDiffPipeline(
  negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

  if self.text_encoder is not None:
- if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
  # Retrieve the original scale by scaling back the LoRA layers
  unscale_lora_layers(self.text_encoder, lora_scale)

@@ -352,6 +354,9 @@ class AnimateDiffPipeline(
  def prepare_ip_adapter_image_embeds(
  self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
  ):
+ image_embeds = []
+ if do_classifier_free_guidance:
+ negative_image_embeds = []
  if ip_adapter_image_embeds is None:
  if not isinstance(ip_adapter_image, list):
  ip_adapter_image = [ip_adapter_image]
@@ -361,7 +366,6 @@ class AnimateDiffPipeline(
  f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
  )

- image_embeds = []
  for single_ip_adapter_image, image_proj_layer in zip(
  ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
  ):
@@ -369,46 +373,43 @@ class AnimateDiffPipeline(
  single_image_embeds, single_negative_image_embeds = self.encode_image(
  single_ip_adapter_image, device, 1, output_hidden_state
  )
- single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
- single_negative_image_embeds = torch.stack(
- [single_negative_image_embeds] * num_images_per_prompt, dim=0
- )

+ image_embeds.append(single_image_embeds[None, :])
  if do_classifier_free_guidance:
- single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
- single_image_embeds = single_image_embeds.to(device)
-
- image_embeds.append(single_image_embeds)
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
  else:
- repeat_dims = [1]
- image_embeds = []
  for single_image_embeds in ip_adapter_image_embeds:
  if do_classifier_free_guidance:
  single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
- single_image_embeds = single_image_embeds.repeat(
- num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
- )
- single_negative_image_embeds = single_negative_image_embeds.repeat(
- num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
- )
- single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
- else:
- single_image_embeds = single_image_embeds.repeat(
- num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
- )
+ negative_image_embeds.append(single_negative_image_embeds)
  image_embeds.append(single_image_embeds)

- return image_embeds
+ ip_adapter_image_embeds = []
+ for i, single_image_embeds in enumerate(image_embeds):
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)

- # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
- def decode_latents(self, latents):
+ return ip_adapter_image_embeds
+
+ def decode_latents(self, latents, decode_chunk_size: int = 16):
  latents = 1 / self.vae.config.scaling_factor * latents

  batch_size, channels, num_frames, height, width = latents.shape
  latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

- image = self.vae.decode(latents).sample
- video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+ video = []
+ for i in range(0, latents.shape[0], decode_chunk_size):
+ batch_latents = latents[i : i + decode_chunk_size]
+ batch_latents = self.vae.decode(batch_latents).sample
+ video.append(batch_latents)
+
+ video = torch.cat(video)
+ video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
  # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
  video = video.float()
  return video
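The rewritten `decode_latents` decodes frames in groups of `decode_chunk_size` to bound peak VAE memory. A minimal, framework-agnostic sketch of that chunking idea (a toy `decode` callable stands in for the VAE):

```python
import torch

def decode_in_chunks(latents: torch.Tensor, decode, decode_chunk_size: int = 16) -> torch.Tensor:
    # Decode `decode_chunk_size` frames at a time, then stitch the results back together.
    chunks = []
    for i in range(0, latents.shape[0], decode_chunk_size):
        chunks.append(decode(latents[i : i + decode_chunk_size]))
    return torch.cat(chunks)

frames = decode_in_chunks(torch.randn(40, 4, 8, 8), decode=lambda x: x * 2.0, decode_chunk_size=16)
print(frames.shape)  # torch.Size([40, 4, 8, 8])
```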
@@ -501,10 +502,21 @@ class AnimateDiffPipeline(
  f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
  )

- # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
  def prepare_latents(
  self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
  ):
+ # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
+ if self.free_noise_enabled:
+ latents = self._prepare_latents_free_noise(
+ batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
  shape = (
  batch_size,
  num_channels_latents,
@@ -512,11 +524,6 @@ class AnimateDiffPipeline(
  height // self.vae_scale_factor,
  width // self.vae_scale_factor,
  )
- if isinstance(generator, list) and len(generator) != batch_size:
- raise ValueError(
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
- )

  if latents is None:
  latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -575,6 +582,7 @@ class AnimateDiffPipeline(
  clip_skip: Optional[int] = None,
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ decode_chunk_size: int = 16,
  **kwargs,
  ):
  r"""
@@ -643,6 +651,8 @@ class AnimateDiffPipeline(
  The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
  `._callback_tensor_inputs` attribute of your pipeline class.
+ decode_chunk_size (`int`, defaults to `16`):
+ The number of frames to decode at a time when calling `decode_latents` method.

  Examples:

@@ -814,7 +824,7 @@ class AnimateDiffPipeline(
  if output_type == "latent":
  video = latents
  else:
- video_tensor = self.decode_latents(latents)
+ video_tensor = self.decode_latents(latents, decode_chunk_size)
  video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

  # 10. Offload all models
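End-to-end, the new `decode_chunk_size` argument is simply forwarded from `__call__` to `decode_latents`. A hedged usage sketch (same illustrative checkpoint IDs as in the earlier LoRA example; a CUDA device is assumed for fp16):

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# decode_chunk_size (new in 0.30.0) caps how many frames the VAE decodes at once.
result = pipe("a rocket launching into space", num_frames=16, decode_chunk_size=8)
frames = result.frames[0]
```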