diffusers 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl

Files changed (238)
  1. diffusers/__init__.py +26 -2
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +13 -8
  4. diffusers/dependency_versions_check.py +0 -1
  5. diffusers/dependency_versions_table.py +5 -5
  6. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  7. diffusers/image_processor.py +463 -51
  8. diffusers/loaders/__init__.py +82 -0
  9. diffusers/loaders/ip_adapter.py +159 -0
  10. diffusers/loaders/lora.py +1553 -0
  11. diffusers/loaders/lora_conversion_utils.py +284 -0
  12. diffusers/loaders/single_file.py +637 -0
  13. diffusers/loaders/textual_inversion.py +455 -0
  14. diffusers/loaders/unet.py +828 -0
  15. diffusers/loaders/utils.py +59 -0
  16. diffusers/models/__init__.py +26 -9
  17. diffusers/models/activations.py +9 -6
  18. diffusers/models/attention.py +301 -29
  19. diffusers/models/attention_flax.py +9 -1
  20. diffusers/models/attention_processor.py +378 -6
  21. diffusers/models/autoencoders/__init__.py +5 -0
  22. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
  23. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
  24. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
  25. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
  26. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
  27. diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
  28. diffusers/models/controlnet.py +59 -39
  29. diffusers/models/controlnet_flax.py +19 -18
  30. diffusers/models/downsampling.py +338 -0
  31. diffusers/models/embeddings.py +112 -29
  32. diffusers/models/embeddings_flax.py +2 -0
  33. diffusers/models/lora.py +131 -1
  34. diffusers/models/modeling_flax_utils.py +14 -8
  35. diffusers/models/modeling_outputs.py +17 -0
  36. diffusers/models/modeling_utils.py +37 -29
  37. diffusers/models/normalization.py +110 -4
  38. diffusers/models/resnet.py +299 -652
  39. diffusers/models/transformer_2d.py +22 -5
  40. diffusers/models/transformer_temporal.py +183 -1
  41. diffusers/models/unet_2d_blocks_flax.py +5 -0
  42. diffusers/models/unet_2d_condition.py +46 -0
  43. diffusers/models/unet_2d_condition_flax.py +13 -13
  44. diffusers/models/unet_3d_blocks.py +957 -173
  45. diffusers/models/unet_3d_condition.py +16 -8
  46. diffusers/models/unet_kandinsky3.py +535 -0
  47. diffusers/models/unet_motion_model.py +48 -33
  48. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  49. diffusers/models/upsampling.py +454 -0
  50. diffusers/models/uvit_2d.py +471 -0
  51. diffusers/models/vae_flax.py +7 -0
  52. diffusers/models/vq_model.py +12 -3
  53. diffusers/optimization.py +16 -9
  54. diffusers/pipelines/__init__.py +137 -76
  55. diffusers/pipelines/amused/__init__.py +62 -0
  56. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  57. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  58. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  59. diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
  60. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  61. diffusers/pipelines/auto_pipeline.py +23 -13
  62. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  63. diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
  64. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
  65. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
  66. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
  67. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
  68. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
  69. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  70. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  71. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  72. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  73. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  74. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  75. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  76. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  77. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  78. diffusers/pipelines/deprecated/__init__.py +153 -0
  79. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  80. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
  81. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
  82. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  83. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  84. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  85. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  86. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  87. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  88. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  89. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  90. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  91. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  92. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  93. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
  94. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  95. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  96. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  97. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  98. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  100. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
  101. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
  102. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
  103. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
  104. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
  105. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
  106. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  107. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  108. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  109. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
  110. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  111. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
  112. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
  113. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
  114. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  115. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  116. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  117. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  118. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  119. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  120. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  122. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  123. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  124. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
  125. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
  126. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
  127. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
  128. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  129. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  130. diffusers/pipelines/onnx_utils.py +8 -5
  131. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  132. diffusers/pipelines/pipeline_flax_utils.py +11 -8
  133. diffusers/pipelines/pipeline_utils.py +63 -42
  134. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
  135. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  136. diffusers/pipelines/stable_diffusion/__init__.py +37 -65
  137. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
  138. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  139. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  140. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  141. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
  142. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  143. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  144. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
  145. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
  146. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
  147. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  151. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  152. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
  153. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  154. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
  155. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  156. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
  157. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  158. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  159. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
  160. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  161. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
  162. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  163. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
  164. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  165. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  166. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
  171. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  172. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
  175. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
  179. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
  180. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  181. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  182. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  183. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  184. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  185. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  186. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  187. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
  188. diffusers/schedulers/__init__.py +4 -4
  189. diffusers/schedulers/deprecated/__init__.py +50 -0
  190. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  191. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  192. diffusers/schedulers/scheduling_amused.py +162 -0
  193. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  194. diffusers/schedulers/scheduling_ddim.py +1 -3
  195. diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
  196. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  197. diffusers/schedulers/scheduling_ddpm.py +47 -3
  198. diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
  199. diffusers/schedulers/scheduling_deis_multistep.py +28 -6
  200. diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
  201. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
  202. diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
  203. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
  204. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
  205. diffusers/schedulers/scheduling_euler_discrete.py +102 -16
  206. diffusers/schedulers/scheduling_heun_discrete.py +17 -5
  207. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
  208. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
  209. diffusers/schedulers/scheduling_lcm.py +123 -29
  210. diffusers/schedulers/scheduling_lms_discrete.py +3 -3
  211. diffusers/schedulers/scheduling_pndm.py +1 -3
  212. diffusers/schedulers/scheduling_repaint.py +1 -3
  213. diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
  214. diffusers/schedulers/scheduling_utils.py +3 -1
  215. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  216. diffusers/training_utils.py +1 -1
  217. diffusers/utils/__init__.py +1 -2
  218. diffusers/utils/constants.py +10 -12
  219. diffusers/utils/dummy_pt_objects.py +75 -0
  220. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  221. diffusers/utils/dynamic_modules_utils.py +18 -22
  222. diffusers/utils/export_utils.py +8 -3
  223. diffusers/utils/hub_utils.py +24 -36
  224. diffusers/utils/logging.py +11 -11
  225. diffusers/utils/outputs.py +5 -5
  226. diffusers/utils/peft_utils.py +88 -44
  227. diffusers/utils/state_dict_utils.py +8 -0
  228. diffusers/utils/testing_utils.py +199 -1
  229. diffusers/utils/torch_utils.py +4 -4
  230. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
  231. diffusers-0.25.0.dist-info/RECORD +360 -0
  232. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  233. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  234. diffusers/loaders.py +0 -3336
  235. diffusers-0.23.1.dist-info/RECORD +0 -323
  236. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  237. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  238. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
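
Item 179 above, `diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py`, is shown in full in the hunk below: it adds a zero-shot text-to-video pipeline built on Stable Diffusion XL, driven by the cross-frame attention processors defined at the top of the file. For orientation only, a minimal usage sketch (the checkpoint id, prompt, and the imageio export are illustrative assumptions, not taken from the diff):

import imageio
import torch

from diffusers import TextToVideoZeroSDXLPipeline

# Load an SDXL checkpoint into the new zero-shot text-to-video pipeline
# (assumed model id; the pipeline swaps in cross-frame attention internally).
pipe = TextToVideoZeroSDXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Generate a short clip; frames come back in `result.images` as float arrays.
result = pipe(prompt="a panda surfing a wave", video_length=8)
frames = [(frame * 255).astype("uint8") for frame in result.images]
imageio.mimsave("video.mp4", frames, fps=4)
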
@@ -0,0 +1,1331 @@
+ import copy
+ import inspect
+ from dataclasses import dataclass
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ import PIL
+ import torch
+ import torch.nn.functional as F
+ from torch.nn.functional import grid_sample
+ from transformers import (
+     CLIPImageProcessor,
+     CLIPTextModel,
+     CLIPTextModelWithProjection,
+     CLIPTokenizer,
+     CLIPVisionModelWithProjection,
+ )
+
+ from ...image_processor import VaeImageProcessor
+ from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+ from ...models import AutoencoderKL, UNet2DConditionModel
+ from ...models.attention_processor import (
+     AttnProcessor2_0,
+     FusedAttnProcessor2_0,
+     LoRAAttnProcessor2_0,
+     LoRAXFormersAttnProcessor,
+     XFormersAttnProcessor,
+ )
+ from ...models.lora import adjust_lora_scale_text_encoder
+ from ...schedulers import KarrasDiffusionSchedulers
+ from ...utils import (
+     USE_PEFT_BACKEND,
+     BaseOutput,
+     is_invisible_watermark_available,
+     logging,
+     scale_lora_layers,
+     unscale_lora_layers,
+ )
+ from ...utils.torch_utils import randn_tensor
+ from ..pipeline_utils import DiffusionPipeline
+
+
+ if is_invisible_watermark_available():
+     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_0
+ def rearrange_0(tensor, f):
+     F, C, H, W = tensor.size()
+     tensor = torch.permute(torch.reshape(tensor, (F // f, f, C, H, W)), (0, 2, 1, 3, 4))
+     return tensor
+
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_1
+ def rearrange_1(tensor):
+     B, C, F, H, W = tensor.size()
+     return torch.reshape(torch.permute(tensor, (0, 2, 1, 3, 4)), (B * F, C, H, W))
+
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_3
+ def rearrange_3(tensor, f):
+     F, D, C = tensor.size()
+     return torch.reshape(tensor, (F // f, f, D, C))
+
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_4
+ def rearrange_4(tensor):
+     B, F, D, C = tensor.size()
+     return torch.reshape(tensor, (B * F, D, C))
+
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor
+ class CrossFrameAttnProcessor:
+     """
+     Cross frame attention processor. Each frame attends to the first frame.
+
+     Args:
+         batch_size: The number that represents actual batch size, other than the frames.
+             For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
+             2, due to classifier-free guidance.
+     """
+
+     def __init__(self, batch_size=2):
+         self.batch_size = batch_size
+
+     def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+         batch_size, sequence_length, _ = hidden_states.shape
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+         query = attn.to_q(hidden_states)
+
+         is_cross_attention = encoder_hidden_states is not None
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         # Cross Frame Attention
+         if not is_cross_attention:
+             video_length = key.size()[0] // self.batch_size
+             first_frame_index = [0] * video_length
+
+             # rearrange keys to have batch and frames in the 1st and 2nd dims respectively
+             key = rearrange_3(key, video_length)
+             key = key[:, first_frame_index]
+             # rearrange values to have batch and frames in the 1st and 2nd dims respectively
+             value = rearrange_3(value, video_length)
+             value = value[:, first_frame_index]
+
+             # rearrange back to original shape
+             key = rearrange_4(key)
+             value = rearrange_4(value)
+
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+
+         attention_probs = attn.get_attention_scores(query, key, attention_mask)
+         hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         return hidden_states
+
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor2_0
+ class CrossFrameAttnProcessor2_0:
+     """
+     Cross frame attention processor with scaled_dot_product attention of PyTorch 2.0.
+
+     Args:
+         batch_size: The number that represents actual batch size, other than the frames.
+             For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
+             2, due to classifier-free guidance.
+     """
+
+     def __init__(self, batch_size=2):
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+         self.batch_size = batch_size
+
+     def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+         inner_dim = hidden_states.shape[-1]
+
+         if attention_mask is not None:
+             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+             # scaled_dot_product_attention expects attention_mask shape to be
+             # (batch, heads, source_length, target_length)
+             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+         query = attn.to_q(hidden_states)
+
+         is_cross_attention = encoder_hidden_states is not None
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         # Cross Frame Attention
+         if not is_cross_attention:
+             video_length = max(1, key.size()[0] // self.batch_size)
+             first_frame_index = [0] * video_length
+
+             # rearrange keys to have batch and frames in the 1st and 2nd dims respectively
+             key = rearrange_3(key, video_length)
+             key = key[:, first_frame_index]
+             # rearrange values to have batch and frames in the 1st and 2nd dims respectively
+             value = rearrange_3(value, video_length)
+             value = value[:, first_frame_index]
+
+             # rearrange back to original shape
+             key = rearrange_4(key)
+             value = rearrange_4(value)
+
+         head_dim = inner_dim // attn.heads
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+         # TODO: add support for attn.scale when we move to Torch 2.1
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+         return hidden_states
+
+
+ @dataclass
+ class TextToVideoSDXLPipelineOutput(BaseOutput):
+     """
+     Output class for zero-shot text-to-video pipeline.
+
+     Args:
+         images (`List[PIL.Image.Image]` or `np.ndarray`):
+             List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+             num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+     """
+
+     images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.coords_grid
+ def coords_grid(batch, ht, wd, device):
+     # Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py
+     coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
+     coords = torch.stack(coords[::-1], dim=0).float()
+     return coords[None].repeat(batch, 1, 1, 1)
+
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.warp_single_latent
+ def warp_single_latent(latent, reference_flow):
+     """
+     Warp latent of a single frame with given flow
+
+     Args:
+         latent: latent code of a single frame
+         reference_flow: flow which to warp the latent with
+
+     Returns:
+         warped: warped latent
+     """
+     _, _, H, W = reference_flow.size()
+     _, _, h, w = latent.size()
+     coords0 = coords_grid(1, H, W, device=latent.device).to(latent.dtype)
+
+     coords_t0 = coords0 + reference_flow
+     coords_t0[:, 0] /= W
+     coords_t0[:, 1] /= H
+
+     coords_t0 = coords_t0 * 2.0 - 1.0
+     coords_t0 = F.interpolate(coords_t0, size=(h, w), mode="bilinear")
+     coords_t0 = torch.permute(coords_t0, (0, 2, 3, 1))
+
+     warped = grid_sample(latent, coords_t0, mode="nearest", padding_mode="reflection")
+     return warped
+
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.create_motion_field
+ def create_motion_field(motion_field_strength_x, motion_field_strength_y, frame_ids, device, dtype):
+     """
+     Create translation motion field
+
+     Args:
+         motion_field_strength_x: motion strength along x-axis
+         motion_field_strength_y: motion strength along y-axis
+         frame_ids: indexes of the frames the latents of which are being processed.
+             This is needed when we perform chunk-by-chunk inference
+         device: device
+         dtype: dtype
+
+     Returns:
+
+     """
+     seq_length = len(frame_ids)
+     reference_flow = torch.zeros((seq_length, 2, 512, 512), device=device, dtype=dtype)
+     for fr_idx in range(seq_length):
+         reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * (frame_ids[fr_idx])
+         reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * (frame_ids[fr_idx])
+     return reference_flow
+
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.create_motion_field_and_warp_latents
+ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_strength_y, frame_ids, latents):
+     """
+     Creates translation motion and warps the latents accordingly
+
+     Args:
+         motion_field_strength_x: motion strength along x-axis
+         motion_field_strength_y: motion strength along y-axis
+         frame_ids: indexes of the frames the latents of which are being processed.
+             This is needed when we perform chunk-by-chunk inference
+         latents: latent codes of frames
+
+     Returns:
+         warped_latents: warped latents
+     """
+     motion_field = create_motion_field(
+         motion_field_strength_x=motion_field_strength_x,
+         motion_field_strength_y=motion_field_strength_y,
+         frame_ids=frame_ids,
+         device=latents.device,
+         dtype=latents.dtype,
+     )
+     warped_latents = latents.clone().detach()
+     for i in range(len(warped_latents)):
+         warped_latents[i] = warp_single_latent(latents[i][None], motion_field[i][None])
+     return warped_latents
+
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+     """
+     Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+     Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+     """
+     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+     # rescale the results from guidance (fixes overexposure)
+     noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+     # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+     noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+     return noise_cfg
+
+
+ class TextToVideoZeroSDXLPipeline(
+     DiffusionPipeline,
+     StableDiffusionXLLoraLoaderMixin,
+     TextualInversionLoaderMixin,
+ ):
+     r"""
+     Pipeline for zero-shot text-to-video generation using Stable Diffusion XL.
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+     implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+     Args:
+         vae ([`AutoencoderKL`]):
+             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+         text_encoder ([`CLIPTextModel`]):
+             Frozen text-encoder. Stable Diffusion XL uses the text portion of
+             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+             the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+         text_encoder_2 ([`CLIPTextModelWithProjection`]):
+             Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+             specifically the
+             [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+             variant.
+         tokenizer (`CLIPTokenizer`):
+             Tokenizer of class
+             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+         tokenizer_2 (`CLIPTokenizer`):
+             Second Tokenizer of class
+             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+         scheduler ([`SchedulerMixin`]):
+             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+     """
+
+     model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
+     _optional_components = [
+         "tokenizer",
+         "tokenizer_2",
+         "text_encoder",
+         "text_encoder_2",
+         "image_encoder",
+         "feature_extractor",
+     ]
+
+     def __init__(
+         self,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModel,
+         text_encoder_2: CLIPTextModelWithProjection,
+         tokenizer: CLIPTokenizer,
+         tokenizer_2: CLIPTokenizer,
+         unet: UNet2DConditionModel,
+         scheduler: KarrasDiffusionSchedulers,
+         image_encoder: CLIPVisionModelWithProjection = None,
+         feature_extractor: CLIPImageProcessor = None,
+         force_zeros_for_empty_prompt: bool = True,
+         add_watermarker: Optional[bool] = None,
+     ):
+         super().__init__()
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             text_encoder_2=text_encoder_2,
+             tokenizer=tokenizer,
+             tokenizer_2=tokenizer_2,
+             unet=unet,
+             scheduler=scheduler,
+             image_encoder=image_encoder,
+             feature_extractor=feature_extractor,
+         )
+         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+         self.default_sample_size = self.unet.config.sample_size
+
+         add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+         if add_watermarker:
+             self.watermark = StableDiffusionXLWatermarker()
+         else:
+             self.watermark = None
+
+         processor = (
+             CrossFrameAttnProcessor2_0(batch_size=2)
+             if hasattr(F, "scaled_dot_product_attention")
+             else CrossFrameAttnProcessor(batch_size=2)
+         )
+
+         self.unet.set_attn_processor(processor)
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+     def prepare_extra_step_kwargs(self, generator, eta):
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be between [0, 1]
+
+         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         # check if the scheduler accepts generator
+         accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         if accepts_generator:
+             extra_step_kwargs["generator"] = generator
+         return extra_step_kwargs
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+     def enable_vae_slicing(self):
+         r"""
+         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+         compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+         """
+         self.vae.enable_slicing()
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+     def disable_vae_slicing(self):
+         r"""
+         Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+         computing decoding in one step.
+         """
+         self.vae.disable_slicing()
+
+     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
+     def upcast_vae(self):
+         dtype = self.vae.dtype
+         self.vae.to(dtype=torch.float32)
+         use_torch_2_0_or_xformers = isinstance(
+             self.vae.decoder.mid_block.attentions[0].processor,
+             (
+                 AttnProcessor2_0,
+                 XFormersAttnProcessor,
+                 LoRAXFormersAttnProcessor,
+                 LoRAAttnProcessor2_0,
+                 FusedAttnProcessor2_0,
+             ),
+         )
+         # if xformers or torch_2_0 is used attention block does not need
+         # to be in float32 which can save lots of memory
+         if use_torch_2_0_or_xformers:
+             self.vae.post_quant_conv.to(dtype)
+             self.vae.decoder.conv_in.to(dtype)
+             self.vae.decoder.mid_block.to(dtype)
+
+     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
+     def _get_add_time_ids(
+         self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
+     ):
+         add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+         passed_add_embed_dim = (
+             self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+         )
+         expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+         if expected_add_embed_dim != passed_add_embed_dim:
+             raise ValueError(
+                 f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+             )
+
+         add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+         return add_time_ids
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+         if isinstance(generator, list) and len(generator) != batch_size:
+             raise ValueError(
+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+             )
+
+         if latents is None:
+             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+         else:
+             latents = latents.to(device)
+
+         # scale the initial noise by the standard deviation required by the scheduler
+         latents = latents * self.scheduler.init_noise_sigma
+         return latents
+
+     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs
+     def check_inputs(
+         self,
+         prompt,
+         prompt_2,
+         height,
+         width,
+         callback_steps,
+         negative_prompt=None,
+         negative_prompt_2=None,
+         prompt_embeds=None,
+         negative_prompt_embeds=None,
+         pooled_prompt_embeds=None,
+         negative_pooled_prompt_embeds=None,
+         callback_on_step_end_tensor_inputs=None,
+     ):
+         if height % 8 != 0 or width % 8 != 0:
+             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+         if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+             raise ValueError(
+                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                 f" {type(callback_steps)}."
+             )
+
+         if callback_on_step_end_tensor_inputs is not None and not all(
+             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+         ):
+             raise ValueError(
+                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+             )
+
+         if prompt is not None and prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                 " only forward one of the two."
+             )
+         elif prompt_2 is not None and prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                 " only forward one of the two."
+             )
+         elif prompt is None and prompt_embeds is None:
+             raise ValueError(
+                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+             )
+         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+         if negative_prompt is not None and negative_prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+             )
+         elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+             )
+
+         if prompt_embeds is not None and negative_prompt_embeds is not None:
+             if prompt_embeds.shape != negative_prompt_embeds.shape:
+                 raise ValueError(
+                     "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                     f" {negative_prompt_embeds.shape}."
+                 )
+
+         if prompt_embeds is not None and pooled_prompt_embeds is None:
+             raise ValueError(
+                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+             )
+
+         if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+             raise ValueError(
+                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+             )
+
+     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+     def encode_prompt(
+         self,
+         prompt: str,
+         prompt_2: Optional[str] = None,
+         device: Optional[torch.device] = None,
+         num_images_per_prompt: int = 1,
+         do_classifier_free_guidance: bool = True,
+         negative_prompt: Optional[str] = None,
+         negative_prompt_2: Optional[str] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         lora_scale: Optional[float] = None,
+         clip_skip: Optional[int] = None,
+     ):
+         r"""
+         Encodes the prompt into text encoder hidden states.
+
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 prompt to be encoded
+             prompt_2 (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                 used in both text-encoders
+             device: (`torch.device`):
+                 torch device
+             num_images_per_prompt (`int`):
+                 number of images that should be generated per prompt
+             do_classifier_free_guidance (`bool`):
+                 whether to use classifier free guidance or not
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                 less than `1`).
+             negative_prompt_2 (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                 `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+             prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                 provided, text embeddings will be generated from `prompt` input argument.
+             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                 argument.
+             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
+             negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                 input argument.
+             lora_scale (`float`, *optional*):
+                 A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+             clip_skip (`int`, *optional*):
+                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                 the output of the pre-final layer will be used for computing the prompt embeddings.
+         """
+         device = device or self._execution_device
+
+         # set lora scale so that monkey patched LoRA
+         # function of text encoder can correctly access it
+         if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+             self._lora_scale = lora_scale
+
+             # dynamically adjust the LoRA scale
+             if self.text_encoder is not None:
+                 if not USE_PEFT_BACKEND:
+                     adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+                 else:
+                     scale_lora_layers(self.text_encoder, lora_scale)
+
+             if self.text_encoder_2 is not None:
+                 if not USE_PEFT_BACKEND:
+                     adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+                 else:
+                     scale_lora_layers(self.text_encoder_2, lora_scale)
+
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+
+         if prompt is not None:
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         # Define tokenizers and text encoders
+         tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+         text_encoders = (
+             [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+         )
+
+         if prompt_embeds is None:
+             prompt_2 = prompt_2 or prompt
+             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+             # textual inversion: process multi-vector tokens if necessary
+             prompt_embeds_list = []
+             prompts = [prompt, prompt_2]
+             for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+                 if isinstance(self, TextualInversionLoaderMixin):
+                     prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+                 text_inputs = tokenizer(
+                     prompt,
+                     padding="max_length",
+                     max_length=tokenizer.model_max_length,
+                     truncation=True,
+                     return_tensors="pt",
+                 )
+
+                 text_input_ids = text_inputs.input_ids
+                 untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+                 if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                     text_input_ids, untruncated_ids
+                 ):
+                     removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                     logger.warning(
+                         "The following part of your input was truncated because CLIP can only handle sequences up to"
+                         f" {tokenizer.model_max_length} tokens: {removed_text}"
+                     )
+
+                 prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+                 # We are only ALWAYS interested in the pooled output of the final text encoder
+                 pooled_prompt_embeds = prompt_embeds[0]
+                 if clip_skip is None:
+                     prompt_embeds = prompt_embeds.hidden_states[-2]
+                 else:
+                     # "2" because SDXL always indexes from the penultimate layer.
+                     prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+                 prompt_embeds_list.append(prompt_embeds)
+
+             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+         # get unconditional embeddings for classifier free guidance
+         zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+         if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+             negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+             negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+         elif do_classifier_free_guidance and negative_prompt_embeds is None:
+             negative_prompt = negative_prompt or ""
+             negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+             # normalize str to list
+             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+             negative_prompt_2 = (
+                 batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+             )
+
+             uncond_tokens: List[str]
+             if prompt is not None and type(prompt) is not type(negative_prompt):
+                 raise TypeError(
+                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                     f" {type(prompt)}."
+                 )
+             elif batch_size != len(negative_prompt):
+                 raise ValueError(
+                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                     " the batch size of `prompt`."
+                 )
+             else:
+                 uncond_tokens = [negative_prompt, negative_prompt_2]
+
+             negative_prompt_embeds_list = []
+             for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+                 if isinstance(self, TextualInversionLoaderMixin):
+                     negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+                 max_length = prompt_embeds.shape[1]
+                 uncond_input = tokenizer(
+                     negative_prompt,
+                     padding="max_length",
+                     max_length=max_length,
+                     truncation=True,
+                     return_tensors="pt",
+                 )
+
+                 negative_prompt_embeds = text_encoder(
+                     uncond_input.input_ids.to(device),
+                     output_hidden_states=True,
+                 )
+                 # We are only ALWAYS interested in the pooled output of the final text encoder
+                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+                 negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+         if self.text_encoder_2 is not None:
+             prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+         else:
+             prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+         bs_embed, seq_len, _ = prompt_embeds.shape
+         # duplicate text embeddings for each generation per prompt, using mps friendly method
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+         if do_classifier_free_guidance:
+             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+             seq_len = negative_prompt_embeds.shape[1]
+
+             if self.text_encoder_2 is not None:
+                 negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+             else:
+                 negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+             bs_embed * num_images_per_prompt, -1
+         )
+         if do_classifier_free_guidance:
+             negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+                 bs_embed * num_images_per_prompt, -1
+             )
+
+         if self.text_encoder is not None:
+             if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                 # Retrieve the original scale by scaling back the LoRA layers
+                 unscale_lora_layers(self.text_encoder, lora_scale)
+
+         if self.text_encoder_2 is not None:
+             if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                 # Retrieve the original scale by scaling back the LoRA layers
+                 unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+     # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoZeroPipeline.forward_loop
+     def forward_loop(self, x_t0, t0, t1, generator):
+         """
+         Perform DDPM forward process from time t0 to t1. This is the same as adding noise with corresponding variance.
+
+         Args:
+             x_t0:
+                 Latent code at time t0.
+             t0:
+                 Timestep at t0.
+             t1:
+                 Timestep at t1.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                 generation deterministic.
+
+         Returns:
+             x_t1:
+                 Forward process applied to x_t0 from time t0 to t1.
+         """
+         eps = randn_tensor(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device)
+         alpha_vec = torch.prod(self.scheduler.alphas[t0:t1])
+         x_t1 = torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps
+         return x_t1
+
+ def backward_loop(
+ self,
+ latents,
+ timesteps,
+ prompt_embeds,
+ guidance_scale,
+ callback,
+ callback_steps,
+ num_warmup_steps,
+ extra_step_kwargs,
+ add_text_embeds,
+ add_time_ids,
+ cross_attention_kwargs=None,
+ guidance_rescale: float = 0.0,
+ ):
+ """
+ Perform the backward (denoising) process over the given list of timesteps.
+
+ Args:
+ latents:
+ Latents at time timesteps[0].
+ timesteps:
+ Timesteps along which to perform the backward process.
+ prompt_embeds:
+ Pre-generated text embeddings.
+ guidance_scale:
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ callback (`Callable`, *optional*):
+ A function that is called every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ extra_step_kwargs:
+ Extra kwargs to pass to the scheduler `step` call.
+ cross_attention_kwargs:
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ num_warmup_steps:
+ Number of warmup steps.
+
+ Returns:
+ latents: Latents of the backward process output at time timesteps[-1].
+ """
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+ num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order
+
+ with self.progress_bar(total=num_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on Section 3.4 of https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+ return latents.clone().detach()
+
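The guidance block in `backward_loop` first forms the usual classifier-free-guidance combination and then, when `guidance_rescale > 0`, rescales the result toward the standard deviation of the text-conditioned prediction (Section 3.4 of the paper linked above). A standalone sketch on dummy tensors, with the rescaling written out inline rather than calling the library's `rescale_noise_cfg` helper (which may differ in detail):

import torch

# Sketch of the classifier-free-guidance combination and the optional rescaling
# from https://arxiv.org/pdf/2305.08891.pdf, written out on dummy tensors.
guidance_scale, guidance_rescale = 7.5, 0.7

noise_pred = torch.randn(2, 4, 64, 64)                      # [uncond, text] stacked on the batch dim
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# Rescale: match the std of the guided prediction to the text-conditioned prediction,
# then blend with the unrescaled result by guidance_rescale.
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True)
rescaled = noise_pred * (std_text / std_cfg)
noise_pred = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_pred
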
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ video_length: Optional[int] = 8,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ frame_ids: Optional[List[int]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ motion_field_strength_x: float = 12,
+ motion_field_strength_y: float = 12,
+ output_type: Optional[str] = "tensor",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ t0: int = 44,
+ t1: int = 47,
+ ):
+ """
975
+ Function invoked when calling the pipeline for generation.
976
+
977
+ Args:
978
+ prompt (`str` or `List[str]`, *optional*):
979
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
980
+ instead.
981
+ prompt_2 (`str` or `List[str]`, *optional*):
982
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
983
+ used in both text-encoders
984
+ video_length (`int`, *optional*, defaults to 8):
985
+ The number of generated video frames.
986
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
987
+ The height in pixels of the generated image.
988
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
989
+ The width in pixels of the generated image.
990
+ num_inference_steps (`int`, *optional*, defaults to 50):
991
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
992
+ expense of slower inference.
993
+ denoising_end (`float`, *optional*):
994
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
995
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
996
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
997
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
998
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
999
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
1000
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1001
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1002
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1003
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1004
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1005
+ usually at the expense of lower image quality.
1006
+ negative_prompt (`str` or `List[str]`, *optional*):
1007
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1008
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1009
+ less than `1`).
1010
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1011
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1012
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
1013
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
1014
+ The number of videos to generate per prompt.
1015
+ eta (`float`, *optional*, defaults to 0.0):
1016
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1017
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1018
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1019
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1020
+ to make generation deterministic.
1021
+ frame_ids (`List[int]`, *optional*):
1022
+ Indexes of the frames that are being generated. This is used when generating longer videos
1023
+ chunk-by-chunk.
1024
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1025
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1026
+ provided, text embeddings will be generated from `prompt` input argument.
1027
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1028
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1029
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1030
+ argument.
1031
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1032
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1033
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1034
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1035
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1036
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1037
+ input argument.
1038
+ latents (`torch.FloatTensor`, *optional*):
1039
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1040
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1041
+ tensor will ge generated by sampling using the supplied random `generator`.
1042
+ motion_field_strength_x (`float`, *optional*, defaults to 12):
1043
+ Strength of motion in generated video along x-axis. See the [paper](https://arxiv.org/abs/2303.13439),
1044
+ Sect. 3.3.1.
1045
+ motion_field_strength_y (`float`, *optional*, defaults to 12):
1046
+ Strength of motion in generated video along y-axis. See the [paper](https://arxiv.org/abs/2303.13439),
1047
+ Sect. 3.3.1.
1048
+ output_type (`str`, *optional*, defaults to `"pil"`):
1049
+ The output format of the generate image. Choose between
1050
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1051
+ return_dict (`bool`, *optional*, defaults to `True`):
1052
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
1053
+ of a plain tuple.
1054
+ callback (`Callable`, *optional*):
1055
+ A function that will be called every `callback_steps` steps during inference. The function will be
1056
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1057
+ callback_steps (`int`, *optional*, defaults to 1):
1058
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1059
+ called at every step.
1060
+ cross_attention_kwargs (`dict`, *optional*):
1061
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1062
+ `self.processor` in
1063
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1064
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
1065
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
1066
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
1067
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
1068
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
1069
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1070
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1071
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
1072
+ explained in section 2.2 of
1073
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1074
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1075
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1076
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1077
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1078
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1079
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1080
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1081
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
1082
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1083
+ t0 (`int`, *optional*, defaults to 44):
1084
+ Timestep t0. Should be in the range [0, num_inference_steps - 1]. See the
1085
+ [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
1086
+ t1 (`int`, *optional*, defaults to 47):
1087
+ Timestep t0. Should be in the range [t0 + 1, num_inference_steps - 1]. See the
1088
+ [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
1089
+
1090
+ Returns:
1091
+ [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoSDXLPipelineOutput`] or
1092
+ `tuple`: [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoSDXLPipelineOutput`]
1093
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
1094
+ generated images.
1095
+ """
1096
+ assert video_length > 0
+ if frame_ids is None:
+ frame_ids = list(range(video_length))
+ assert len(frame_ids) == video_length
+
+ assert num_videos_per_prompt == 1
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ if isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ batch_size = (
+ 1 if isinstance(prompt, str) else len(prompt) if isinstance(prompt, list) else prompt_embeds.shape[0]
+ )
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_videos_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1)
+
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ # Perform the first backward process up to time T_1
+ x_1_t1 = self.backward_loop(
+ timesteps=timesteps[: -t1 - 1],
+ prompt_embeds=prompt_embeds,
+ latents=latents,
+ guidance_scale=guidance_scale,
+ callback=callback,
+ callback_steps=callback_steps,
+ extra_step_kwargs=extra_step_kwargs,
+ num_warmup_steps=num_warmup_steps,
+ add_text_embeds=add_text_embeds,
+ add_time_ids=add_time_ids,
+ )
+
+ scheduler_copy = copy.deepcopy(self.scheduler)
+
+ # Perform the second backward process up to time T_0
+ x_1_t0 = self.backward_loop(
+ timesteps=timesteps[-t1 - 1 : -t0 - 1],
+ prompt_embeds=prompt_embeds,
+ latents=x_1_t1,
+ guidance_scale=guidance_scale,
+ callback=callback,
+ callback_steps=callback_steps,
+ extra_step_kwargs=extra_step_kwargs,
+ num_warmup_steps=0,
+ add_text_embeds=add_text_embeds,
+ add_time_ids=add_time_ids,
+ )
+
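For orientation, `t0` and `t1` index the inference schedule from its end, so with the defaults (`num_inference_steps=50`, `t0=44`, `t1=47`) the slices above and below split the schedule as in this small sketch; the warped copies of the first frame are then re-noised from the T_0 level back up to T_1 by `forward_loop` below, before the final joint backward pass over all frames:

# Slice arithmetic only; list(range(...)) stands in for scheduler.timesteps.
num_inference_steps, t0, t1 = 50, 44, 47
timesteps = list(range(num_inference_steps))

first_backward = timesteps[: -t1 - 1]            # 2 steps: pure noise down to T_1
second_backward = timesteps[-t1 - 1 : -t0 - 1]   # 3 steps: T_1 down to T_0 (first frame only)
final_backward = timesteps[-t1 - 1 :]            # 48 steps: T_1 down to 0 (all frames jointly)
print(len(first_backward), len(second_backward), len(final_backward))   # 2 3 48
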
+ # Propagate first frame latents at time T_0 to remaining frames
+ x_2k_t0 = x_1_t0.repeat(video_length - 1, 1, 1, 1)
+
+ # Add motion in latents at time T_0
+ x_2k_t0 = create_motion_field_and_warp_latents(
+ motion_field_strength_x=motion_field_strength_x,
+ motion_field_strength_y=motion_field_strength_y,
+ latents=x_2k_t0,
+ frame_ids=frame_ids[1:],
+ )
+
+ # Perform forward process up to time T_1
+ x_2k_t1 = self.forward_loop(
+ x_t0=x_2k_t0,
+ t0=timesteps[-t0 - 1].to(torch.long),
+ t1=timesteps[-t1 - 1].to(torch.long),
+ generator=generator,
+ )
+
+ # Perform backward process from time T_1 to 0
+ latents = torch.cat([x_1_t1, x_2k_t1])
+
+ self.scheduler = scheduler_copy
+ timesteps = timesteps[-t1 - 1 :]
+
+ b, l, d = prompt_embeds.size()
+ prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d)
+
+ b, k = add_text_embeds.size()
+ add_text_embeds = add_text_embeds[:, None].repeat(1, video_length, 1).reshape(b * video_length, k)
+
+ b, k = add_time_ids.size()
+ add_time_ids = add_time_ids[:, None].repeat(1, video_length, 1).reshape(b * video_length, k)
+
+ # 7.1 Apply denoising_end
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ x_1k_0 = self.backward_loop(
+ timesteps=timesteps,
+ prompt_embeds=prompt_embeds,
+ latents=latents,
+ guidance_scale=guidance_scale,
+ callback=callback,
+ callback_steps=callback_steps,
+ extra_step_kwargs=extra_step_kwargs,
+ num_warmup_steps=0,
+ add_text_embeds=add_text_embeds,
+ add_time_ids=add_time_ids,
+ )
+
+ latents = x_1k_0
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+ return TextToVideoSDXLPipelineOutput(images=image)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return TextToVideoSDXLPipelineOutput(images=image)
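Assuming this file is exported as `TextToVideoZeroSDXLPipeline` (the SDXL variant of the Text2Video-Zero pipeline), end-to-end usage would look roughly like the sketch below; the model id, dtype settings, and frame post-processing are assumptions that follow the pattern of the existing Text2Video-Zero examples rather than anything stated in this diff:

import imageio
import torch
from diffusers import TextToVideoZeroSDXLPipeline

# Illustrative usage sketch (model id and settings are assumptions, not from the diff).
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = TextToVideoZeroSDXLPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")

prompt = "A panda dancing in Antarctica"
result = pipe(prompt=prompt, video_length=8).images    # list of frames with the default output_type
frames = [(frame * 255).astype("uint8") for frame in result]
imageio.mimsave("video.mp4", frames, fps=4)            # requires imageio-ffmpeg for mp4 output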