diffusers 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl

Files changed (238)
  1. diffusers/__init__.py +26 -2
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +13 -8
  4. diffusers/dependency_versions_check.py +0 -1
  5. diffusers/dependency_versions_table.py +5 -5
  6. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  7. diffusers/image_processor.py +463 -51
  8. diffusers/loaders/__init__.py +82 -0
  9. diffusers/loaders/ip_adapter.py +159 -0
  10. diffusers/loaders/lora.py +1553 -0
  11. diffusers/loaders/lora_conversion_utils.py +284 -0
  12. diffusers/loaders/single_file.py +637 -0
  13. diffusers/loaders/textual_inversion.py +455 -0
  14. diffusers/loaders/unet.py +828 -0
  15. diffusers/loaders/utils.py +59 -0
  16. diffusers/models/__init__.py +26 -9
  17. diffusers/models/activations.py +9 -6
  18. diffusers/models/attention.py +301 -29
  19. diffusers/models/attention_flax.py +9 -1
  20. diffusers/models/attention_processor.py +378 -6
  21. diffusers/models/autoencoders/__init__.py +5 -0
  22. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
  23. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
  24. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
  25. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
  26. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
  27. diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
  28. diffusers/models/controlnet.py +59 -39
  29. diffusers/models/controlnet_flax.py +19 -18
  30. diffusers/models/downsampling.py +338 -0
  31. diffusers/models/embeddings.py +112 -29
  32. diffusers/models/embeddings_flax.py +2 -0
  33. diffusers/models/lora.py +131 -1
  34. diffusers/models/modeling_flax_utils.py +14 -8
  35. diffusers/models/modeling_outputs.py +17 -0
  36. diffusers/models/modeling_utils.py +37 -29
  37. diffusers/models/normalization.py +110 -4
  38. diffusers/models/resnet.py +299 -652
  39. diffusers/models/transformer_2d.py +22 -5
  40. diffusers/models/transformer_temporal.py +183 -1
  41. diffusers/models/unet_2d_blocks_flax.py +5 -0
  42. diffusers/models/unet_2d_condition.py +46 -0
  43. diffusers/models/unet_2d_condition_flax.py +13 -13
  44. diffusers/models/unet_3d_blocks.py +957 -173
  45. diffusers/models/unet_3d_condition.py +16 -8
  46. diffusers/models/unet_kandinsky3.py +535 -0
  47. diffusers/models/unet_motion_model.py +48 -33
  48. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  49. diffusers/models/upsampling.py +454 -0
  50. diffusers/models/uvit_2d.py +471 -0
  51. diffusers/models/vae_flax.py +7 -0
  52. diffusers/models/vq_model.py +12 -3
  53. diffusers/optimization.py +16 -9
  54. diffusers/pipelines/__init__.py +137 -76
  55. diffusers/pipelines/amused/__init__.py +62 -0
  56. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  57. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  58. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  59. diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
  60. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  61. diffusers/pipelines/auto_pipeline.py +23 -13
  62. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  63. diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
  64. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
  65. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
  66. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
  67. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
  68. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
  69. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  70. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  71. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  72. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  73. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  74. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  75. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  76. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  77. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  78. diffusers/pipelines/deprecated/__init__.py +153 -0
  79. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  80. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
  81. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
  82. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  83. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  84. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  85. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  86. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  87. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  88. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  89. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  90. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  91. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  92. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  93. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
  94. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  95. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  96. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  97. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  98. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  100. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
  101. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
  102. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
  103. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
  104. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
  105. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
  106. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  107. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  108. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  109. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
  110. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  111. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
  112. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
  113. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
  114. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  115. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  116. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  117. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  118. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  119. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  120. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  122. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  123. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  124. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
  125. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
  126. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
  127. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
  128. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  129. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  130. diffusers/pipelines/onnx_utils.py +8 -5
  131. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  132. diffusers/pipelines/pipeline_flax_utils.py +11 -8
  133. diffusers/pipelines/pipeline_utils.py +63 -42
  134. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
  135. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  136. diffusers/pipelines/stable_diffusion/__init__.py +37 -65
  137. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
  138. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  139. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  140. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  141. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
  142. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  143. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  144. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
  145. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
  146. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
  147. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  151. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  152. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
  153. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  154. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
  155. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  156. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
  157. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  158. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  159. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
  160. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  161. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
  162. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  163. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
  164. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  165. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  166. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
  171. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  172. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
  175. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
  179. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
  180. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  181. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  182. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  183. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  184. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  185. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  186. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  187. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
  188. diffusers/schedulers/__init__.py +4 -4
  189. diffusers/schedulers/deprecated/__init__.py +50 -0
  190. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  191. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  192. diffusers/schedulers/scheduling_amused.py +162 -0
  193. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  194. diffusers/schedulers/scheduling_ddim.py +1 -3
  195. diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
  196. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  197. diffusers/schedulers/scheduling_ddpm.py +47 -3
  198. diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
  199. diffusers/schedulers/scheduling_deis_multistep.py +28 -6
  200. diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
  201. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
  202. diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
  203. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
  204. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
  205. diffusers/schedulers/scheduling_euler_discrete.py +102 -16
  206. diffusers/schedulers/scheduling_heun_discrete.py +17 -5
  207. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
  208. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
  209. diffusers/schedulers/scheduling_lcm.py +123 -29
  210. diffusers/schedulers/scheduling_lms_discrete.py +3 -3
  211. diffusers/schedulers/scheduling_pndm.py +1 -3
  212. diffusers/schedulers/scheduling_repaint.py +1 -3
  213. diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
  214. diffusers/schedulers/scheduling_utils.py +3 -1
  215. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  216. diffusers/training_utils.py +1 -1
  217. diffusers/utils/__init__.py +1 -2
  218. diffusers/utils/constants.py +10 -12
  219. diffusers/utils/dummy_pt_objects.py +75 -0
  220. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  221. diffusers/utils/dynamic_modules_utils.py +18 -22
  222. diffusers/utils/export_utils.py +8 -3
  223. diffusers/utils/hub_utils.py +24 -36
  224. diffusers/utils/logging.py +11 -11
  225. diffusers/utils/outputs.py +5 -5
  226. diffusers/utils/peft_utils.py +88 -44
  227. diffusers/utils/state_dict_utils.py +8 -0
  228. diffusers/utils/testing_utils.py +199 -1
  229. diffusers/utils/torch_utils.py +4 -4
  230. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
  231. diffusers-0.25.0.dist-info/RECORD +360 -0
  232. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  233. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  234. diffusers/loaders.py +0 -3336
  235. diffusers-0.23.1.dist-info/RECORD +0 -323
  236. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  237. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  238. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
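Two structural changes stand out in the table above: the monolithic diffusers/loaders.py (entry 234, removed) is split into the new diffusers/loaders/ package (entries 8-15), and a batch of older pipelines moves under diffusers/pipelines/deprecated/ (entries 78-115). For code that imports those pipelines by module path, the path shifts with the move. A minimal sketch using AltDiffusionPipeline (entries 79-82); the top-level re-export should be unaffected:

# 0.23.1: the module lived directly under diffusers.pipelines
from diffusers.pipelines.alt_diffusion import AltDiffusionPipeline

# 0.25.0: same module, now under the deprecated/ subpackage (entry 79)
from diffusers.pipelines.deprecated.alt_diffusion import AltDiffusionPipeline

# The public top-level import works in both versions:
from diffusers import AltDiffusionPipeline

The largest new model file in the table is diffusers/models/unet_spatio_temporal_condition.py (entry 48, +489 -0); its full diff follows.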
diffusers/models/unet_spatio_temporal_condition.py (entry 48, new file)
@@ -0,0 +1,489 @@
+ from dataclasses import dataclass
+ from typing import Dict, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+ from ..configuration_utils import ConfigMixin, register_to_config
+ from ..loaders import UNet2DConditionLoadersMixin
+ from ..utils import BaseOutput, logging
+ from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
+ from .embeddings import TimestepEmbedding, Timesteps
+ from .modeling_utils import ModelMixin
+ from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ @dataclass
+ class UNetSpatioTemporalConditionOutput(BaseOutput):
+     """
+     The output of [`UNetSpatioTemporalConditionModel`].
+
+     Args:
+         sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
+             The hidden states output conditioned on `encoder_hidden_states` input. Output of the last layer of the
+             model.
+     """
+
+     sample: torch.FloatTensor = None
+
+
+ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+     r"""
+     A conditional spatio-temporal UNet model that takes noisy video frames, a conditional state, and a timestep,
+     and returns a sample-shaped output.
+
+     This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+     for all models (such as downloading or saving).
+
+     Parameters:
+         sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+             Height and width of input/output sample.
+         in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
+         out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
+             The tuple of downsample blocks to use.
+         up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
+             The tuple of upsample blocks to use.
+         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+             The tuple of output channels for each block.
+         addition_time_embed_dim (`int`, defaults to 256):
+             Dimension used to encode the additional time ids.
+         projection_class_embeddings_input_dim (`int`, defaults to 768):
+             The dimension of the projection of encoded `added_time_ids`.
+         layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024):
+             The dimension of the cross attention features.
+         transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]`, *optional*, defaults to 1):
+             The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+             [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
+             [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
+         num_attention_heads (`int` or `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
+             The number of attention heads.
+         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+     """
+
+     _supports_gradient_checkpointing = True
+
+     @register_to_config
+     def __init__(
+         self,
+         sample_size: Optional[int] = None,
+         in_channels: int = 8,
+         out_channels: int = 4,
+         down_block_types: Tuple[str] = (
+             "CrossAttnDownBlockSpatioTemporal",
+             "CrossAttnDownBlockSpatioTemporal",
+             "CrossAttnDownBlockSpatioTemporal",
+             "DownBlockSpatioTemporal",
+         ),
+         up_block_types: Tuple[str] = (
+             "UpBlockSpatioTemporal",
+             "CrossAttnUpBlockSpatioTemporal",
+             "CrossAttnUpBlockSpatioTemporal",
+             "CrossAttnUpBlockSpatioTemporal",
+         ),
+         block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+         addition_time_embed_dim: int = 256,
+         projection_class_embeddings_input_dim: int = 768,
+         layers_per_block: Union[int, Tuple[int]] = 2,
+         cross_attention_dim: Union[int, Tuple[int]] = 1024,
+         transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+         num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
+         num_frames: int = 25,
+     ):
+         super().__init__()
+
+         self.sample_size = sample_size
+
+         # Check inputs
+         if len(down_block_types) != len(up_block_types):
+             raise ValueError(
+                 f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+             )
+
+         if len(block_out_channels) != len(down_block_types):
+             raise ValueError(
+                 f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+             )
+
+         if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+             raise ValueError(
+                 f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+             )
+
+         if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+             raise ValueError(
+                 f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+             )
+
+         if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+             raise ValueError(
+                 f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+             )
+
+         # input
+         self.conv_in = nn.Conv2d(
+             in_channels,
+             block_out_channels[0],
+             kernel_size=3,
+             padding=1,
+         )
+
+         # time
+         time_embed_dim = block_out_channels[0] * 4
+
+         self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
+         timestep_input_dim = block_out_channels[0]
+
+         self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+         self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
+         self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+         self.down_blocks = nn.ModuleList([])
+         self.up_blocks = nn.ModuleList([])
+
+         if isinstance(num_attention_heads, int):
+             num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+         if isinstance(cross_attention_dim, int):
+             cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+         if isinstance(layers_per_block, int):
+             layers_per_block = [layers_per_block] * len(down_block_types)
+
+         if isinstance(transformer_layers_per_block, int):
+             transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+         blocks_time_embed_dim = time_embed_dim
+
+         # down
+         output_channel = block_out_channels[0]
+         for i, down_block_type in enumerate(down_block_types):
+             input_channel = output_channel
+             output_channel = block_out_channels[i]
+             is_final_block = i == len(block_out_channels) - 1
+
+             down_block = get_down_block(
+                 down_block_type,
+                 num_layers=layers_per_block[i],
+                 transformer_layers_per_block=transformer_layers_per_block[i],
+                 in_channels=input_channel,
+                 out_channels=output_channel,
+                 temb_channels=blocks_time_embed_dim,
+                 add_downsample=not is_final_block,
+                 resnet_eps=1e-5,
+                 cross_attention_dim=cross_attention_dim[i],
+                 num_attention_heads=num_attention_heads[i],
+                 resnet_act_fn="silu",
+             )
+             self.down_blocks.append(down_block)
+
+         # mid
+         self.mid_block = UNetMidBlockSpatioTemporal(
+             block_out_channels[-1],
+             temb_channels=blocks_time_embed_dim,
+             transformer_layers_per_block=transformer_layers_per_block[-1],
+             cross_attention_dim=cross_attention_dim[-1],
+             num_attention_heads=num_attention_heads[-1],
+         )
+
+         # count how many layers upsample the images
+         self.num_upsamplers = 0
+
+         # up
+         reversed_block_out_channels = list(reversed(block_out_channels))
+         reversed_num_attention_heads = list(reversed(num_attention_heads))
+         reversed_layers_per_block = list(reversed(layers_per_block))
+         reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+         reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+
+         output_channel = reversed_block_out_channels[0]
+         for i, up_block_type in enumerate(up_block_types):
+             is_final_block = i == len(block_out_channels) - 1
+
+             prev_output_channel = output_channel
+             output_channel = reversed_block_out_channels[i]
+             input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+             # add upsample block for all BUT final layer
+             if not is_final_block:
+                 add_upsample = True
+                 self.num_upsamplers += 1
+             else:
+                 add_upsample = False
+
+             up_block = get_up_block(
+                 up_block_type,
+                 num_layers=reversed_layers_per_block[i] + 1,
+                 transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+                 in_channels=input_channel,
+                 out_channels=output_channel,
+                 prev_output_channel=prev_output_channel,
+                 temb_channels=blocks_time_embed_dim,
+                 add_upsample=add_upsample,
+                 resnet_eps=1e-5,
+                 resolution_idx=i,
+                 cross_attention_dim=reversed_cross_attention_dim[i],
+                 num_attention_heads=reversed_num_attention_heads[i],
+                 resnet_act_fn="silu",
+             )
+             self.up_blocks.append(up_block)
+             prev_output_channel = output_channel
+
+         # out
+         self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
+         self.conv_act = nn.SiLU()
+
+         self.conv_out = nn.Conv2d(
+             block_out_channels[0],
+             out_channels,
+             kernel_size=3,
+             padding=1,
+         )
+
+     @property
+     def attn_processors(self) -> Dict[str, AttentionProcessor]:
+         r"""
+         Returns:
+             `dict` of attention processors: A dictionary containing all attention processors used in the model,
+             indexed by their weight names.
+         """
+         # set recursively
+         processors = {}
+
+         def fn_recursive_add_processors(
+             name: str,
+             module: torch.nn.Module,
+             processors: Dict[str, AttentionProcessor],
+         ):
+             if hasattr(module, "get_processor"):
+                 processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+             return processors
+
+         for name, module in self.named_children():
+             fn_recursive_add_processors(name, module, processors)
+
+         return processors
+
+     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+         r"""
+         Sets the attention processor to use to compute attention.
+
+         Parameters:
+             processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                 The instantiated processor class or a dictionary of processor classes that will be set as the
+                 processor for **all** `Attention` layers.
+
+                 If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                 processor. This is strongly recommended when setting trainable attention processors.
+
+         """
+         count = len(self.attn_processors.keys())
+
+         if isinstance(processor, dict) and len(processor) != count:
+             raise ValueError(
+                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+             )
+
+         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+             if hasattr(module, "set_processor"):
+                 if not isinstance(processor, dict):
+                     module.set_processor(processor)
+                 else:
+                     module.set_processor(processor.pop(f"{name}.processor"))
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+         for name, module in self.named_children():
+             fn_recursive_attn_processor(name, module, processor)
+
+     def set_default_attn_processor(self):
+         """
+         Disables custom attention processors and sets the default attention implementation.
+         """
+         if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+             processor = AttnProcessor()
+         else:
+             raise ValueError(
+                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+             )
+
+         self.set_attn_processor(processor)
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if hasattr(module, "gradient_checkpointing"):
+             module.gradient_checkpointing = value
+
+     # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+     def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+         """
+         Enables [feed forward
+         chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+         Parameters:
+             chunk_size (`int`, *optional*):
+                 The chunk size of the feed-forward layers. If not specified, the feed-forward layer is run
+                 individually over each tensor of dim=`dim`.
+             dim (`int`, *optional*, defaults to `0`):
+                 The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                 or dim=1 (sequence length).
+         """
+         if dim not in [0, 1]:
+             raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+         # By default chunk size is 1
+         chunk_size = chunk_size or 1
+
+         def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+             if hasattr(module, "set_chunk_feed_forward"):
+                 module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+             for child in module.children():
+                 fn_recursive_feed_forward(child, chunk_size, dim)
+
+         for module in self.children():
+             fn_recursive_feed_forward(module, chunk_size, dim)
+
+     def forward(
+         self,
+         sample: torch.FloatTensor,
+         timestep: Union[torch.Tensor, float, int],
+         encoder_hidden_states: torch.Tensor,
+         added_time_ids: torch.Tensor,
+         return_dict: bool = True,
+     ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
+         r"""
+         The [`UNetSpatioTemporalConditionModel`] forward method.
+
+         Args:
+             sample (`torch.FloatTensor`):
+                 The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
+             timestep (`torch.FloatTensor` or `float` or `int`): The current denoising timestep.
+             encoder_hidden_states (`torch.FloatTensor`):
+                 The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
+             added_time_ids (`torch.FloatTensor`):
+                 The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
+                 embeddings and added to the time embeddings.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~models.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`]
+                 instead of a plain tuple.
+         Returns:
+             [`~models.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`] or `tuple`:
+                 If `return_dict` is True, an [`~models.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`]
+                 is returned, otherwise a `tuple` is returned where the first element is the sample tensor.
+         """
+         # 1. time
+         timesteps = timestep
+         if not torch.is_tensor(timesteps):
+             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+             # This would be a good case for the `match` statement (Python 3.10+)
+             is_mps = sample.device.type == "mps"
+             if isinstance(timestep, float):
+                 dtype = torch.float32 if is_mps else torch.float64
+             else:
+                 dtype = torch.int32 if is_mps else torch.int64
+             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+         elif len(timesteps.shape) == 0:
+             timesteps = timesteps[None].to(sample.device)
+
+         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+         batch_size, num_frames = sample.shape[:2]
+         timesteps = timesteps.expand(batch_size)
+
+         t_emb = self.time_proj(timesteps)
+
+         # `Timesteps` does not contain any weights and will always return f32 tensors
+         # but time_embedding might actually be running in fp16. so we need to cast here.
+         # there might be better ways to encapsulate this.
+         t_emb = t_emb.to(dtype=sample.dtype)
+
+         emb = self.time_embedding(t_emb)
+
+         time_embeds = self.add_time_proj(added_time_ids.flatten())
+         time_embeds = time_embeds.reshape((batch_size, -1))
+         time_embeds = time_embeds.to(emb.dtype)
+         aug_emb = self.add_embedding(time_embeds)
+         emb = emb + aug_emb
+
+         # Flatten the batch and frames dimensions
+         # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
+         sample = sample.flatten(0, 1)
+         # Repeat the embeddings num_video_frames times
+         # emb: [batch, channels] -> [batch * frames, channels]
+         emb = emb.repeat_interleave(num_frames, dim=0)
+         # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
+         encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+
+         # 2. pre-process
+         sample = self.conv_in(sample)
+
+         image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
+
+         down_block_res_samples = (sample,)
+         for downsample_block in self.down_blocks:
+             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                 sample, res_samples = downsample_block(
+                     hidden_states=sample,
+                     temb=emb,
+                     encoder_hidden_states=encoder_hidden_states,
+                     image_only_indicator=image_only_indicator,
+                 )
+             else:
+                 sample, res_samples = downsample_block(
+                     hidden_states=sample,
+                     temb=emb,
+                     image_only_indicator=image_only_indicator,
+                 )
+
+             down_block_res_samples += res_samples
+
+         # 4. mid
+         sample = self.mid_block(
+             hidden_states=sample,
+             temb=emb,
+             encoder_hidden_states=encoder_hidden_states,
+             image_only_indicator=image_only_indicator,
+         )
+
+         # 5. up
+         for i, upsample_block in enumerate(self.up_blocks):
+             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+             if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+                 sample = upsample_block(
+                     hidden_states=sample,
+                     temb=emb,
+                     res_hidden_states_tuple=res_samples,
+                     encoder_hidden_states=encoder_hidden_states,
+                     image_only_indicator=image_only_indicator,
+                 )
+             else:
+                 sample = upsample_block(
+                     hidden_states=sample,
+                     temb=emb,
+                     res_hidden_states_tuple=res_samples,
+                     image_only_indicator=image_only_indicator,
+                 )
+
+         # 6. post-process
+         sample = self.conv_norm_out(sample)
+         sample = self.conv_act(sample)
+         sample = self.conv_out(sample)
+
+         # 7. Reshape back to original shape
+         sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
+
+         if not return_dict:
+             return (sample,)
+
+         return UNetSpatioTemporalConditionOutput(sample=sample)
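
UNetSpatioTemporalConditionModel is the denoiser behind the new Stable Video Diffusion pipeline (entries 48 and 171-172 in the table). As a quick smoke test of the forward contract documented above, here is a minimal sketch; the tiny config values and random tensors are illustrative assumptions, not a released checkpoint's configuration (real weights load via UNetSpatioTemporalConditionModel.from_pretrained):

import torch

from diffusers import UNetSpatioTemporalConditionModel

# Toy configuration: every value below is an assumption chosen to be small
# enough for CPU; released SVD checkpoints use the defaults shown in the
# diff above (block_out_channels=(320, 640, 1280, 1280), etc.).
unet = UNetSpatioTemporalConditionModel(
    sample_size=32,
    in_channels=8,
    out_channels=4,
    down_block_types=("CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal"),
    up_block_types=("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    cross_attention_dim=32,
    num_attention_heads=(2, 4),
    addition_time_embed_dim=32,
    projection_class_embeddings_input_dim=96,  # 3 added time ids * addition_time_embed_dim
    num_frames=4,
).eval()

batch, frames = 1, 4
sample = torch.randn(batch, frames, 8, 32, 32)     # (batch, num_frames, channel, height, width)
timestep = torch.tensor(10)                        # current denoising timestep
encoder_hidden_states = torch.randn(batch, 1, 32)  # (batch, sequence_length, cross_attention_dim)
added_time_ids = torch.randn(batch, 3)             # SVD passes fps, motion bucket id, noise-aug strength here

unet.set_default_attn_processor()  # the method defined in the diff above
with torch.no_grad():
    out = unet(sample, timestep, encoder_hidden_states, added_time_ids).sample

print(out.shape)  # expected: torch.Size([1, 4, 4, 32, 32])

Note how the (batch, frames, ...) input is flattened to (batch * frames, ...) internally while the time and text embeddings are repeat_interleave'd per frame, which is why a 5D input comes back as a 5D output.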