diffusers-0.27.2-py3-none-any.whl → diffusers-0.28.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +19 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  229. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  231. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  232. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
  267. diffusers-0.27.2.dist-info/RECORD +0 -399
  268. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  269. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -15,11 +15,10 @@
  import inspect
  from typing import Any, Callable, Dict, List, Optional, Union

- import numpy as np
  import torch
  from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

- from ...image_processor import PipelineImageInput, VaeImageProcessor
+ from ...image_processor import PipelineImageInput
  from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
  from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
  from ...models.lora import adjust_lora_scale_text_encoder
@@ -34,6 +33,7 @@ from ...schedulers import (
  )
  from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
  from ...utils.torch_utils import randn_tensor
+ from ...video_processor import VideoProcessor
  from ..free_init_utils import FreeInitMixin
  from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
  from .pipeline_output import AnimateDiffPipelineOutput
@@ -52,14 +52,21 @@ EXAMPLE_DOC_STRING = """
  >>> from io import BytesIO
  >>> from PIL import Image

- >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
- >>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter).to("cuda")
- >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
+ >>> adapter = MotionAdapter.from_pretrained(
+ ... "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
+ ... )
+ >>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
+ ... "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter
+ ... ).to("cuda")
+ >>> pipe.scheduler = DDIMScheduler(
+ ... beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace"
+ ... )
+

  >>> def load_video(file_path: str):
  ... images = []
- ...
- ... if file_path.startswith(('http://', 'https://')):
+
+ ... if file_path.startswith(("http://", "https://")):
  ... # If the file_path is a URL
  ... response = requests.get(file_path)
  ... response.raise_for_status()
@@ -68,43 +75,26 @@ EXAMPLE_DOC_STRING = """
  ... else:
  ... # Assuming it's a local file path
  ... vid = imageio.get_reader(file_path)
- ...
+
  ... for frame in vid:
  ... pil_image = Image.fromarray(frame)
  ... images.append(pil_image)
- ...
+
  ... return images

- >>> video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif")
- >>> output = pipe(video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5)
+
+ >>> video = load_video(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif"
+ ... )
+ >>> output = pipe(
+ ... video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5
+ ... )
  >>> frames = output.frames[0]
  >>> export_to_gif(frames, "animation.gif")
  ```
  """


- # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
- def tensor2vid(video: torch.Tensor, processor, output_type="np"):
- batch_size, channels, num_frames, height, width = video.shape
- outputs = []
- for batch_idx in range(batch_size):
- batch_vid = video[batch_idx].permute(1, 0, 2, 3)
- batch_output = processor.postprocess(batch_vid, output_type)
-
- outputs.append(batch_output)
-
- if output_type == "np":
- outputs = np.stack(outputs)
-
- elif output_type == "pt":
- outputs = torch.stack(outputs)
-
- elif not output_type == "pil":
- raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
-
- return outputs
-
-
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
  def retrieve_latents(
  encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -125,6 +115,7 @@ def retrieve_timesteps(
  num_inference_steps: Optional[int] = None,
  device: Optional[Union[str, torch.device]] = None,
  timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
  **kwargs,
  ):
  """
@@ -135,19 +126,23 @@
  scheduler (`SchedulerMixin`):
  The scheduler to get timesteps from.
  num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
  device (`str` or `torch.device`, *optional*):
  The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
  timesteps (`List[int]`, *optional*):
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
- must be `None`.
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.

  Returns:
  `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
  second element is the number of inference steps.
  """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
  if timesteps is not None:
  accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
  if not accepts_timesteps:
@@ -158,6 +153,16 @@
  scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
  timesteps = scheduler.timesteps
  num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
  else:
  scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
  timesteps = scheduler.timesteps
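
The hunks above add an optional `sigmas` argument to the module-level `retrieve_timesteps` helper, so a caller can hand the scheduler an explicit noise schedule instead of a step count or timestep list. A minimal sketch of the new branch, assuming a scheduler whose `set_timesteps` accepts `sigmas` in this release (`EulerDiscreteScheduler` is used here as such an example; if it does not accept `sigmas`, the helper raises the `ValueError` shown above). The ten-value schedule is purely illustrative:

```python
from diffusers import EulerDiscreteScheduler
from diffusers.pipelines.animatediff.pipeline_animatediff_video2video import retrieve_timesteps

# Default-config scheduler; any scheduler whose `set_timesteps` takes `sigmas` would do.
scheduler = EulerDiscreteScheduler()

# Illustrative, strictly decreasing sigma schedule (not a tuned setting).
sigmas = [14.6, 9.1, 5.6, 3.4, 2.0, 1.1, 0.6, 0.3, 0.15, 0.05]

# The helper forwards the list to scheduler.set_timesteps(sigmas=...) and derives
# num_inference_steps from the resulting timesteps.
timesteps, num_inference_steps = retrieve_timesteps(scheduler, device="cpu", sigmas=sigmas)
print(num_inference_steps, timesteps)
```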
@@ -237,7 +242,7 @@ class AnimateDiffVideoToVideoPipeline(
  image_encoder=image_encoder,
  )
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)

  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
  def encode_prompt(
@@ -247,8 +252,8 @@ class AnimateDiffVideoToVideoPipeline(
  num_images_per_prompt,
  do_classifier_free_guidance,
  negative_prompt=None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
  lora_scale: Optional[float] = None,
  clip_skip: Optional[int] = None,
  ):
@@ -268,10 +273,10 @@ class AnimateDiffVideoToVideoPipeline(
  The prompt or prompts not to guide the image generation. If not defined, one has to pass
  `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
  less than `1`).
- prompt_embeds (`torch.FloatTensor`, *optional*):
+ prompt_embeds (`torch.Tensor`, *optional*):
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
  provided, text embeddings will be generated from `prompt` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
  argument.
@@ -623,16 +628,7 @@ class AnimateDiffVideoToVideoPipeline(
  generator,
  latents=None,
  ):
- # video must be a list of list of images
- # the outer list denotes having multiple videos as input, whereas inner list means the frames of the video
- # as a list of images
- if not isinstance(video[0], list):
- video = [video]
  if latents is None:
- video = torch.cat(
- [self.image_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], dim=0
- )
- video = video.to(device=device, dtype=dtype)
  num_frames = video.shape[1]
  else:
  num_frames = latents.shape[2]
@@ -738,17 +734,18 @@ class AnimateDiffVideoToVideoPipeline(
  width: Optional[int] = None,
  num_inference_steps: int = 50,
  timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
  guidance_scale: float = 7.5,
  strength: float = 0.8,
  negative_prompt: Optional[Union[str, List[str]]] = None,
  num_videos_per_prompt: Optional[int] = 1,
  eta: float = 0.0,
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.FloatTensor] = None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
  ip_adapter_image: Optional[PipelineImageInput] = None,
- ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
  output_type: Optional[str] = "pil",
  return_dict: bool = True,
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -771,6 +768,14 @@ class AnimateDiffVideoToVideoPipeline(
  num_inference_steps (`int`, *optional*, defaults to 50):
  The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
  expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
  strength (`float`, *optional*, defaults to 0.8):
  Higher strength leads to more differences between original video and generated video.
  guidance_scale (`float`, *optional*, defaults to 7.5):
@@ -785,30 +790,28 @@
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
  A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
  generation deterministic.
- latents (`torch.FloatTensor`, *optional*):
+ latents (`torch.Tensor`, *optional*):
  Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
  tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
  `(batch_size, num_channel, num_frames, height, width)`.
- prompt_embeds (`torch.FloatTensor`, *optional*):
+ prompt_embeds (`torch.Tensor`, *optional*):
  Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
  provided, text embeddings are generated from the `prompt` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
  not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
  ip_adapter_image: (`PipelineImageInput`, *optional*):
  Optional image input to work with IP Adapters.
- ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
  output_type (`str`, *optional*, defaults to `"pil"`):
- The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
- `np.array`.
+ The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
  return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`AnimateDiffPipelineOutput`] instead
- of a plain tuple.
+ Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple.
  cross_attention_kwargs (`dict`, *optional*):
  A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
  [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -823,7 +826,7 @@
  callback_on_step_end_tensor_inputs (`List`, *optional*):
  The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeine class.
+ `._callback_tensor_inputs` attribute of your pipeline class.

  Examples:

@@ -901,11 +904,18 @@
  )

  # 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
  timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
  latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)

  # 5. Prepare latent variables
+ if latents is None:
+ video = self.video_processor.preprocess_video(video, height=height, width=width)
+ # Move the number of frames before the number of channels.
+ video = video.permute(0, 2, 1, 3, 4)
+ video = video.to(device=device, dtype=prompt_embeds.dtype)
  num_channels_latents = self.unet.config.in_channels
  latents = self.prepare_latents(
  video=video,
@@ -944,7 +954,7 @@
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

  # 8. Denoising loop
- with self.progress_bar(total=num_inference_steps) as progress_bar:
+ with self.progress_bar(total=self._num_timesteps) as progress_bar:
  for i, t in enumerate(timesteps):
  # expand the latents if we are doing classifier free guidance
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
@@ -986,7 +996,7 @@
  video = latents
  else:
  video_tensor = self.decode_latents(latents)
- video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+ video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

  # 10. Offload all models
  self.maybe_free_model_hooks()
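
The hunks above retire the pipeline's local `tensor2vid` helper and its `VaeImageProcessor` in favour of the new `diffusers.video_processor.VideoProcessor`, whose `preprocess_video`/`postprocess_video` methods take over frame handling. A small sketch of roughly equivalent usage; the shapes and the random tensor are illustrative, not taken from the diff:

```python
import torch
from diffusers.video_processor import VideoProcessor

video_processor = VideoProcessor(vae_scale_factor=8)

# Stand-in for decoded latents: (batch, channels, num_frames, height, width), values in [-1, 1].
video_tensor = torch.rand(1, 3, 8, 64, 64) * 2 - 1

# postprocess_video plays the role of the removed tensor2vid helper.
frames = video_processor.postprocess_video(video=video_tensor, output_type="pil")
print(len(frames), len(frames[0]))  # 1 video with 8 PIL frames

# preprocess_video goes the other way: frames back to a 5D tensor resized to height/width
# (the pipeline hunk above then permutes frames in front of channels).
video_in = video_processor.preprocess_video(frames[0], height=64, width=64)
print(video_in.shape)
```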
@@ -13,9 +13,10 @@ class AnimateDiffPipelineOutput(BaseOutput):
  r"""
  Output class for AnimateDiff pipelines.

- Args:
+ Args:
  frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
- List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing denoised
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised
  PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
  `(batch_size, num_frames, channels, height, width)`
  """
@@ -103,8 +103,8 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
  num_waveforms_per_prompt,
  do_classifier_free_guidance,
  negative_prompt=None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
  ):
  r"""
  Encodes the prompt into text encoder hidden states.
@@ -122,10 +122,10 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
  The prompt or prompts not to guide the audio generation. If not defined, one has to pass
  `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
  less than `1`).
- prompt_embeds (`torch.FloatTensor`, *optional*):
+ prompt_embeds (`torch.Tensor`, *optional*):
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
  provided, text embeddings will be generated from `prompt` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
  argument.
@@ -330,8 +330,8 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
  shape = (
  batch_size,
  num_channels_latents,
- height // self.vae_scale_factor,
- self.vocoder.config.model_in_dim // self.vae_scale_factor,
+ int(height) // self.vae_scale_factor,
+ int(self.vocoder.config.model_in_dim) // self.vae_scale_factor,
  )
  if isinstance(generator, list) and len(generator) != batch_size:
  raise ValueError(
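
The hunk above wraps `height` and `self.vocoder.config.model_in_dim` in `int(...)` before the floor division. The likely motivation (an assumption, not stated in the diff) is that either value can arrive as a float, in which case the resulting shape entries are themselves floats and are rejected when the latent tensor is created. A tiny illustration with hypothetical values:

```python
height, vae_scale_factor = 512.0, 8  # `height` arriving as a float

print(height // vae_scale_factor)        # 64.0 -- still a float, unusable as a tensor dimension
print(int(height) // vae_scale_factor)   # 64   -- what the patched code produces

# e.g. torch.randn(1, 8, height // vae_scale_factor, 16) raises a TypeError,
# while torch.randn(1, 8, int(height) // vae_scale_factor, 16) works.
```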
@@ -360,11 +360,11 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
  num_waveforms_per_prompt: Optional[int] = 1,
  eta: float = 0.0,
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.FloatTensor] = None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
  return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
  callback_steps: Optional[int] = 1,
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
  output_type: Optional[str] = "np",
@@ -394,21 +394,21 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
  A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
  generation deterministic.
- latents (`torch.FloatTensor`, *optional*):
+ latents (`torch.Tensor`, *optional*):
  Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
  tensor is generated by sampling using the supplied random `generator`.
- prompt_embeds (`torch.FloatTensor`, *optional*):
+ prompt_embeds (`torch.Tensor`, *optional*):
  Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
  provided, text embeddings are generated from the `prompt` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
  not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
  return_dict (`bool`, *optional*, defaults to `True`):
  Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
  callback (`Callable`, *optional*):
  A function that calls every `callback_steps` steps during inference. The function is called with the
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
  callback_steps (`int`, *optional*, defaults to 1):
  The frequency at which the `callback` function is called. If not specified, the callback is called at
  every step.
@@ -64,7 +64,7 @@ class AudioLDM2ProjectionModelOutput(BaseOutput):
  """
  Args:
  Class for AudioLDM2 projection layer's outputs.
- hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
  Sequence of hidden-states obtained by linearly projecting the hidden-states for each of the text
  encoders and subsequently concatenating them together.
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -75,7 +75,7 @@ class AudioLDM2ProjectionModelOutput(BaseOutput):
  - 0 for tokens that are **masked**.
  """

- hidden_states: torch.FloatTensor
+ hidden_states: torch.Tensor
  attention_mask: Optional[torch.LongTensor] = None


@@ -95,7 +95,14 @@ class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin):
  """

  @register_to_config
- def __init__(self, text_encoder_dim, text_encoder_1_dim, langauge_model_dim):
+ def __init__(
+ self,
+ text_encoder_dim,
+ text_encoder_1_dim,
+ langauge_model_dim,
+ use_learned_position_embedding=None,
+ max_seq_length=None,
+ ):
  super().__init__()
  # additional projection layers for each text encoder
  self.projection = nn.Linear(text_encoder_dim, langauge_model_dim)
@@ -108,10 +115,18 @@ class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin):
  self.sos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))
  self.eos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))

+ self.use_learned_position_embedding = use_learned_position_embedding
+
+ # learable positional embedding for vits encoder
+ if self.use_learned_position_embedding is not None:
+ self.learnable_positional_embedding = torch.nn.Parameter(
+ torch.zeros((1, text_encoder_1_dim, max_seq_length))
+ )
+
  def forward(
  self,
- hidden_states: Optional[torch.FloatTensor] = None,
- hidden_states_1: Optional[torch.FloatTensor] = None,
+ hidden_states: Optional[torch.Tensor] = None,
+ hidden_states_1: Optional[torch.Tensor] = None,
  attention_mask: Optional[torch.LongTensor] = None,
  attention_mask_1: Optional[torch.LongTensor] = None,
  ):
@@ -120,6 +135,10 @@ class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin):
  hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed
  )

+ # Add positional embedding for Vits hidden state
+ if self.use_learned_position_embedding is not None:
+ hidden_states_1 = (hidden_states_1.permute(0, 2, 1) + self.learnable_positional_embedding).permute(0, 2, 1)
+
  hidden_states_1 = self.projection_1(hidden_states_1)
  hidden_states_1, attention_mask_1 = add_special_tokens(
  hidden_states_1, attention_mask_1, sos_token=self.sos_embed_1, eos_token=self.eos_embed_1
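
The `AudioLDM2ProjectionModel` hunks above add an optional learned positional embedding for the second (VITS) text-encoder branch: a zero-initialised parameter of shape `(1, text_encoder_1_dim, max_seq_length)` added to `hidden_states_1` before projection, with a permute so the addition broadcasts over the batch dimension. A shape-only sketch of that step; all sizes are hypothetical, and the sequence length must equal `max_seq_length` for the addition to broadcast:

```python
import torch

batch, seq_len, feat_dim = 2, 64, 192  # hypothetical sizes
max_seq_length = seq_len               # the addition requires matching lengths

hidden_states_1 = torch.randn(batch, seq_len, feat_dim)  # (B, T, C) from the VITS encoder
pos_embed = torch.zeros(1, feat_dim, max_seq_length)     # as in torch.zeros((1, text_encoder_1_dim, max_seq_length))

# (B, T, C) -> (B, C, T), add the (1, C, T) embedding, then back to (B, T, C),
# mirroring the permute/add/permute introduced in the hunk above.
out = (hidden_states_1.permute(0, 2, 1) + pos_embed).permute(0, 2, 1)
print(out.shape)  # torch.Size([2, 64, 192])
```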
@@ -661,7 +680,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad

  def forward(
  self,
- sample: torch.FloatTensor,
+ sample: torch.Tensor,
  timestep: Union[torch.Tensor, float, int],
  encoder_hidden_states: torch.Tensor,
  class_labels: Optional[torch.Tensor] = None,
@@ -677,10 +696,10 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
  The [`AudioLDM2UNet2DConditionModel`] forward method.

  Args:
- sample (`torch.FloatTensor`):
+ sample (`torch.Tensor`):
  The noisy input tensor with the following shape `(batch, channel, height, width)`.
- timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
- encoder_hidden_states (`torch.FloatTensor`):
+ timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
  The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
  encoder_attention_mask (`torch.Tensor`):
  A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
@@ -691,7 +710,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
  tuple.
  cross_attention_kwargs (`dict`, *optional*):
  A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
- encoder_hidden_states_1 (`torch.FloatTensor`, *optional*):
+ encoder_hidden_states_1 (`torch.Tensor`, *optional*):
  A second set of encoder hidden states with shape `(batch, sequence_length_2, feature_dim_2)`. Can be
  used to condition the model on a different set of embeddings to `encoder_hidden_states`.
  encoder_attention_mask_1 (`torch.Tensor`, *optional*):
@@ -701,8 +720,8 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad

  Returns:
  [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
- If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
- a `tuple` is returned where the first element is the sample tensor.
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is the sample tensor.
  """
  # By default samples have to be AT least a multiple of the overall upsampling factor.
  # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
@@ -1072,14 +1091,14 @@ class CrossAttnDownBlock2D(nn.Module):

  def forward(
  self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- encoder_hidden_states_1: Optional[torch.FloatTensor] = None,
- encoder_attention_mask_1: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states_1: Optional[torch.Tensor] = None,
+ encoder_attention_mask_1: Optional[torch.Tensor] = None,
  ):
  output_states = ()
  num_layers = len(self.resnets)
@@ -1251,15 +1270,15 @@ class UNetMidBlock2DCrossAttn(nn.Module):

  def forward(
  self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- encoder_hidden_states_1: Optional[torch.FloatTensor] = None,
- encoder_attention_mask_1: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states_1: Optional[torch.Tensor] = None,
+ encoder_attention_mask_1: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
  hidden_states = self.resnets[0](hidden_states, temb)
  num_attention_per_layer = len(self.attentions) // (len(self.resnets) - 1)

@@ -1418,16 +1437,16 @@

  def forward(
  self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ hidden_states: torch.Tensor,
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+ temb: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
  upsample_size: Optional[int] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- encoder_hidden_states_1: Optional[torch.FloatTensor] = None,
- encoder_attention_mask_1: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states_1: Optional[torch.Tensor] = None,
+ encoder_attention_mask_1: Optional[torch.Tensor] = None,
  ):
  num_layers = len(self.resnets)
  num_attention_per_layer = len(self.attentions) // num_layers