diffusers 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl

Files changed (238)
  1. diffusers/__init__.py +26 -2
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +13 -8
  4. diffusers/dependency_versions_check.py +0 -1
  5. diffusers/dependency_versions_table.py +5 -5
  6. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  7. diffusers/image_processor.py +463 -51
  8. diffusers/loaders/__init__.py +82 -0
  9. diffusers/loaders/ip_adapter.py +159 -0
  10. diffusers/loaders/lora.py +1553 -0
  11. diffusers/loaders/lora_conversion_utils.py +284 -0
  12. diffusers/loaders/single_file.py +637 -0
  13. diffusers/loaders/textual_inversion.py +455 -0
  14. diffusers/loaders/unet.py +828 -0
  15. diffusers/loaders/utils.py +59 -0
  16. diffusers/models/__init__.py +26 -9
  17. diffusers/models/activations.py +9 -6
  18. diffusers/models/attention.py +301 -29
  19. diffusers/models/attention_flax.py +9 -1
  20. diffusers/models/attention_processor.py +378 -6
  21. diffusers/models/autoencoders/__init__.py +5 -0
  22. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
  23. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
  24. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
  25. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
  26. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
  27. diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
  28. diffusers/models/controlnet.py +59 -39
  29. diffusers/models/controlnet_flax.py +19 -18
  30. diffusers/models/downsampling.py +338 -0
  31. diffusers/models/embeddings.py +112 -29
  32. diffusers/models/embeddings_flax.py +2 -0
  33. diffusers/models/lora.py +131 -1
  34. diffusers/models/modeling_flax_utils.py +14 -8
  35. diffusers/models/modeling_outputs.py +17 -0
  36. diffusers/models/modeling_utils.py +37 -29
  37. diffusers/models/normalization.py +110 -4
  38. diffusers/models/resnet.py +299 -652
  39. diffusers/models/transformer_2d.py +22 -5
  40. diffusers/models/transformer_temporal.py +183 -1
  41. diffusers/models/unet_2d_blocks_flax.py +5 -0
  42. diffusers/models/unet_2d_condition.py +46 -0
  43. diffusers/models/unet_2d_condition_flax.py +13 -13
  44. diffusers/models/unet_3d_blocks.py +957 -173
  45. diffusers/models/unet_3d_condition.py +16 -8
  46. diffusers/models/unet_kandinsky3.py +535 -0
  47. diffusers/models/unet_motion_model.py +48 -33
  48. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  49. diffusers/models/upsampling.py +454 -0
  50. diffusers/models/uvit_2d.py +471 -0
  51. diffusers/models/vae_flax.py +7 -0
  52. diffusers/models/vq_model.py +12 -3
  53. diffusers/optimization.py +16 -9
  54. diffusers/pipelines/__init__.py +137 -76
  55. diffusers/pipelines/amused/__init__.py +62 -0
  56. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  57. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  58. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  59. diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
  60. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  61. diffusers/pipelines/auto_pipeline.py +23 -13
  62. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  63. diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
  64. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
  65. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
  66. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
  67. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
  68. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
  69. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  70. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  71. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  72. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  73. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  74. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  75. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  76. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  77. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  78. diffusers/pipelines/deprecated/__init__.py +153 -0
  79. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  80. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
  81. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
  82. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  83. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  84. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  85. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  86. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  87. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  88. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  89. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  90. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  91. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  92. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  93. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
  94. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  95. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  96. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  97. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  98. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  100. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
  101. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
  102. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
  103. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
  104. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
  105. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
  106. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  107. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  108. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  109. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
  110. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  111. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
  112. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
  113. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
  114. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  115. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  116. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  117. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  118. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  119. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  120. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  122. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  123. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  124. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
  125. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
  126. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
  127. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
  128. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  129. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  130. diffusers/pipelines/onnx_utils.py +8 -5
  131. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  132. diffusers/pipelines/pipeline_flax_utils.py +11 -8
  133. diffusers/pipelines/pipeline_utils.py +63 -42
  134. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
  135. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  136. diffusers/pipelines/stable_diffusion/__init__.py +37 -65
  137. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
  138. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  139. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  140. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  141. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
  142. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  143. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  144. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
  145. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
  146. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
  147. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  151. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  152. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
  153. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  154. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
  155. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  156. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
  157. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  158. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  159. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
  160. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  161. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
  162. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  163. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
  164. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  165. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  166. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
  171. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  172. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
  175. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
  179. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
  180. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  181. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  182. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  183. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  184. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  185. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  186. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  187. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
  188. diffusers/schedulers/__init__.py +4 -4
  189. diffusers/schedulers/deprecated/__init__.py +50 -0
  190. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  191. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  192. diffusers/schedulers/scheduling_amused.py +162 -0
  193. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  194. diffusers/schedulers/scheduling_ddim.py +1 -3
  195. diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
  196. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  197. diffusers/schedulers/scheduling_ddpm.py +47 -3
  198. diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
  199. diffusers/schedulers/scheduling_deis_multistep.py +28 -6
  200. diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
  201. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
  202. diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
  203. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
  204. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
  205. diffusers/schedulers/scheduling_euler_discrete.py +102 -16
  206. diffusers/schedulers/scheduling_heun_discrete.py +17 -5
  207. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
  208. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
  209. diffusers/schedulers/scheduling_lcm.py +123 -29
  210. diffusers/schedulers/scheduling_lms_discrete.py +3 -3
  211. diffusers/schedulers/scheduling_pndm.py +1 -3
  212. diffusers/schedulers/scheduling_repaint.py +1 -3
  213. diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
  214. diffusers/schedulers/scheduling_utils.py +3 -1
  215. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  216. diffusers/training_utils.py +1 -1
  217. diffusers/utils/__init__.py +1 -2
  218. diffusers/utils/constants.py +10 -12
  219. diffusers/utils/dummy_pt_objects.py +75 -0
  220. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  221. diffusers/utils/dynamic_modules_utils.py +18 -22
  222. diffusers/utils/export_utils.py +8 -3
  223. diffusers/utils/hub_utils.py +24 -36
  224. diffusers/utils/logging.py +11 -11
  225. diffusers/utils/outputs.py +5 -5
  226. diffusers/utils/peft_utils.py +88 -44
  227. diffusers/utils/state_dict_utils.py +8 -0
  228. diffusers/utils/testing_utils.py +199 -1
  229. diffusers/utils/torch_utils.py +4 -4
  230. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
  231. diffusers-0.25.0.dist-info/RECORD +360 -0
  232. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  233. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  234. diffusers/loaders.py +0 -3336
  235. diffusers-0.23.1.dist-info/RECORD +0 -323
  236. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  237. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  238. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
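
The headline changes: the monolithic `diffusers/loaders.py` is split into the new `diffusers/loaders/` package, a number of legacy pipelines and schedulers move under `deprecated/` subpackages, and Stable Video Diffusion (SVD) support lands via `unet_spatio_temporal_condition.py`, `autoencoders/autoencoder_kl_temporal_decoder.py`, and `pipelines/stable_video_diffusion/`. A minimal sketch of driving the new SVD pipeline after upgrading; the image file name is a placeholder and a CUDA GPU is assumed:

import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import export_to_video, load_image

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.enable_model_cpu_offload()  # trades speed for lower peak VRAM

# SVD is conditioned on a single image; the model was trained at 1024x576.
image = load_image("conditioning.png").resize((1024, 576))

generator = torch.manual_seed(42)
frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
export_to_video(frames, "generated.mp4", fps=7)

The diff below covers diffusers/models/unet_3d_blocks.py (file 44 above).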
--- a/diffusers/models/unet_3d_blocks.py
+++ b/diffusers/models/unet_3d_blocks.py
@@ -12,40 +12,58 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 from torch import nn
 
 from ..utils import is_torch_version
 from ..utils.torch_utils import apply_freeu
+from .attention import Attention
 from .dual_transformer_2d import DualTransformer2DModel
-from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
+from .resnet import (
+    Downsample2D,
+    ResnetBlock2D,
+    SpatioTemporalResBlock,
+    TemporalConvLayer,
+    Upsample2D,
+)
 from .transformer_2d import Transformer2DModel
-from .transformer_temporal import TransformerTemporalModel
+from .transformer_temporal import (
+    TransformerSpatioTemporalModel,
+    TransformerTemporalModel,
+)
 
 
 def get_down_block(
-    down_block_type,
-    num_layers,
-    in_channels,
-    out_channels,
-    temb_channels,
-    add_downsample,
-    resnet_eps,
-    resnet_act_fn,
-    num_attention_heads,
-    resnet_groups=None,
-    cross_attention_dim=None,
-    downsample_padding=None,
-    dual_cross_attention=False,
-    use_linear_projection=True,
-    only_cross_attention=False,
-    upcast_attention=False,
-    resnet_time_scale_shift="default",
-    temporal_num_attention_heads=8,
-    temporal_max_seq_length=32,
-):
+    down_block_type: str,
+    num_layers: int,
+    in_channels: int,
+    out_channels: int,
+    temb_channels: int,
+    add_downsample: bool,
+    resnet_eps: float,
+    resnet_act_fn: str,
+    num_attention_heads: int,
+    resnet_groups: Optional[int] = None,
+    cross_attention_dim: Optional[int] = None,
+    downsample_padding: Optional[int] = None,
+    dual_cross_attention: bool = False,
+    use_linear_projection: bool = True,
+    only_cross_attention: bool = False,
+    upcast_attention: bool = False,
+    resnet_time_scale_shift: str = "default",
+    temporal_num_attention_heads: int = 8,
+    temporal_max_seq_length: int = 32,
+    transformer_layers_per_block: int = 1,
+) -> Union[
+    "DownBlock3D",
+    "CrossAttnDownBlock3D",
+    "DownBlockMotion",
+    "CrossAttnDownBlockMotion",
+    "DownBlockSpatioTemporal",
+    "CrossAttnDownBlockSpatioTemporal",
+]:
     if down_block_type == "DownBlock3D":
         return DownBlock3D(
             num_layers=num_layers,
@@ -118,33 +136,65 @@ def get_down_block(
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
         )
+    elif down_block_type == "DownBlockSpatioTemporal":
+        # added for SDV
+        return DownBlockSpatioTemporal(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+        )
+    elif down_block_type == "CrossAttnDownBlockSpatioTemporal":
+        # added for SDV
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal")
+        return CrossAttnDownBlockSpatioTemporal(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
+            add_downsample=add_downsample,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads,
+        )
 
     raise ValueError(f"{down_block_type} does not exist.")
 
 
 def get_up_block(
-    up_block_type,
-    num_layers,
-    in_channels,
-    out_channels,
-    prev_output_channel,
-    temb_channels,
-    add_upsample,
-    resnet_eps,
-    resnet_act_fn,
-    num_attention_heads,
-    resolution_idx=None,
-    resnet_groups=None,
-    cross_attention_dim=None,
-    dual_cross_attention=False,
-    use_linear_projection=True,
-    only_cross_attention=False,
-    upcast_attention=False,
-    resnet_time_scale_shift="default",
-    temporal_num_attention_heads=8,
-    temporal_cross_attention_dim=None,
-    temporal_max_seq_length=32,
-):
+    up_block_type: str,
+    num_layers: int,
+    in_channels: int,
+    out_channels: int,
+    prev_output_channel: int,
+    temb_channels: int,
+    add_upsample: bool,
+    resnet_eps: float,
+    resnet_act_fn: str,
+    num_attention_heads: int,
+    resolution_idx: Optional[int] = None,
+    resnet_groups: Optional[int] = None,
+    cross_attention_dim: Optional[int] = None,
+    dual_cross_attention: bool = False,
+    use_linear_projection: bool = True,
+    only_cross_attention: bool = False,
+    upcast_attention: bool = False,
+    resnet_time_scale_shift: str = "default",
+    temporal_num_attention_heads: int = 8,
+    temporal_cross_attention_dim: Optional[int] = None,
+    temporal_max_seq_length: int = 32,
+    transformer_layers_per_block: int = 1,
+    dropout: float = 0.0,
+) -> Union[
+    "UpBlock3D",
+    "CrossAttnUpBlock3D",
+    "UpBlockMotion",
+    "CrossAttnUpBlockMotion",
+    "UpBlockSpatioTemporal",
+    "CrossAttnUpBlockSpatioTemporal",
+]:
     if up_block_type == "UpBlock3D":
         return UpBlock3D(
             num_layers=num_layers,
@@ -221,6 +271,34 @@ def get_up_block(
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
         )
+    elif up_block_type == "UpBlockSpatioTemporal":
+        # added for SDV
+        return UpBlockSpatioTemporal(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            resolution_idx=resolution_idx,
+            add_upsample=add_upsample,
+        )
+    elif up_block_type == "CrossAttnUpBlockSpatioTemporal":
+        # added for SDV
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal")
+        return CrossAttnUpBlockSpatioTemporal(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
+            add_upsample=add_upsample,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads,
+            resolution_idx=resolution_idx,
+        )
+
     raise ValueError(f"{up_block_type} does not exist.")
 
 
@@ -236,12 +314,12 @@ class UNetMidBlock3DCrossAttn(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads=1,
-        output_scale_factor=1.0,
-        cross_attention_dim=1280,
-        dual_cross_attention=False,
-        use_linear_projection=True,
-        upcast_attention=False,
+        num_attention_heads: int = 1,
+        output_scale_factor: float = 1.0,
+        cross_attention_dim: int = 1280,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = True,
+        upcast_attention: bool = False,
     ):
         super().__init__()
 
@@ -269,6 +347,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
                 in_channels,
                 in_channels,
                 dropout=0.1,
+                norm_num_groups=resnet_groups,
             )
         ]
         attentions = []
@@ -316,6 +395,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
                     in_channels,
                     in_channels,
                     dropout=0.1,
+                    norm_num_groups=resnet_groups,
                 )
             )
 
@@ -326,13 +406,13 @@ class UNetMidBlock3DCrossAttn(nn.Module):
 
     def forward(
         self,
-        hidden_states,
-        temb=None,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        num_frames=1,
-        cross_attention_kwargs=None,
-    ):
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> torch.FloatTensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
         for attn, temp_attn, resnet, temp_conv in zip(
@@ -345,7 +425,10 @@ class UNetMidBlock3DCrossAttn(nn.Module):
                 return_dict=False,
             )[0]
             hidden_states = temp_attn(
-                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+                hidden_states,
+                num_frames=num_frames,
+                cross_attention_kwargs=cross_attention_kwargs,
+                return_dict=False,
             )[0]
             hidden_states = resnet(hidden_states, temb)
             hidden_states = temp_conv(hidden_states, num_frames=num_frames)
@@ -366,15 +449,15 @@ class CrossAttnDownBlock3D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads=1,
-        cross_attention_dim=1280,
-        output_scale_factor=1.0,
-        downsample_padding=1,
-        add_downsample=True,
-        dual_cross_attention=False,
-        use_linear_projection=False,
-        only_cross_attention=False,
-        upcast_attention=False,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        downsample_padding: int = 1,
+        add_downsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
     ):
         super().__init__()
         resnets = []
@@ -406,6 +489,7 @@ class CrossAttnDownBlock3D(nn.Module):
                     out_channels,
                     out_channels,
                     dropout=0.1,
+                    norm_num_groups=resnet_groups,
                 )
             )
             attentions.append(
@@ -440,7 +524,11 @@ class CrossAttnDownBlock3D(nn.Module):
             self.downsamplers = nn.ModuleList(
                 [
                     Downsample2D(
-                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                        out_channels,
+                        use_conv=True,
+                        out_channels=out_channels,
+                        padding=downsample_padding,
+                        name="op",
                     )
                 ]
             )
@@ -451,13 +539,13 @@ class CrossAttnDownBlock3D(nn.Module):
 
     def forward(
         self,
-        hidden_states,
-        temb=None,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        num_frames=1,
-        cross_attention_kwargs=None,
-    ):
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+        cross_attention_kwargs: Dict[str, Any] = None,
+    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
         # TODO(Patrick, William) - attention mask is not used
         output_states = ()
 
@@ -473,7 +561,10 @@ class CrossAttnDownBlock3D(nn.Module):
                 return_dict=False,
             )[0]
             hidden_states = temp_attn(
-                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+                hidden_states,
+                num_frames=num_frames,
+                cross_attention_kwargs=cross_attention_kwargs,
+                return_dict=False,
             )[0]
 
             output_states += (hidden_states,)
@@ -500,9 +591,9 @@ class DownBlock3D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        output_scale_factor=1.0,
-        add_downsample=True,
-        downsample_padding=1,
+        output_scale_factor: float = 1.0,
+        add_downsample: bool = True,
+        downsample_padding: int = 1,
     ):
         super().__init__()
         resnets = []
@@ -529,6 +620,7 @@ class DownBlock3D(nn.Module):
                     out_channels,
                     out_channels,
                     dropout=0.1,
+                    norm_num_groups=resnet_groups,
                 )
             )
 
@@ -539,7 +631,11 @@ class DownBlock3D(nn.Module):
         self.downsamplers = nn.ModuleList(
             [
                 Downsample2D(
-                    out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    out_channels,
+                    use_conv=True,
+                    out_channels=out_channels,
+                    padding=downsample_padding,
+                    name="op",
                 )
             ]
         )
@@ -548,7 +644,12 @@ class DownBlock3D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, hidden_states, temb=None, num_frames=1):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
         output_states = ()
 
         for resnet, temp_conv in zip(self.resnets, self.temp_convs):
@@ -580,15 +681,15 @@ class CrossAttnUpBlock3D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads=1,
-        cross_attention_dim=1280,
-        output_scale_factor=1.0,
-        add_upsample=True,
-        dual_cross_attention=False,
-        use_linear_projection=False,
-        only_cross_attention=False,
-        upcast_attention=False,
-        resolution_idx=None,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        resolution_idx: Optional[int] = None,
     ):
         super().__init__()
         resnets = []
@@ -622,6 +723,7 @@ class CrossAttnUpBlock3D(nn.Module):
                     out_channels,
                     out_channels,
                     dropout=0.1,
+                    norm_num_groups=resnet_groups,
                 )
             )
             attentions.append(
@@ -662,15 +764,15 @@ class CrossAttnUpBlock3D(nn.Module):
 
     def forward(
         self,
-        hidden_states,
-        res_hidden_states_tuple,
-        temb=None,
-        encoder_hidden_states=None,
-        upsample_size=None,
-        attention_mask=None,
-        num_frames=1,
-        cross_attention_kwargs=None,
-    ):
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+        cross_attention_kwargs: Dict[str, Any] = None,
+    ) -> torch.FloatTensor:
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -709,7 +811,10 @@ class CrossAttnUpBlock3D(nn.Module):
                 return_dict=False,
             )[0]
             hidden_states = temp_attn(
-                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+                hidden_states,
+                num_frames=num_frames,
+                cross_attention_kwargs=cross_attention_kwargs,
+                return_dict=False,
             )[0]
 
         if self.upsamplers is not None:
@@ -733,9 +838,9 @@ class UpBlock3D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        output_scale_factor=1.0,
-        add_upsample=True,
-        resolution_idx=None,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        resolution_idx: Optional[int] = None,
     ):
         super().__init__()
         resnets = []
@@ -764,6 +869,7 @@ class UpBlock3D(nn.Module):
                     out_channels,
                     out_channels,
                     dropout=0.1,
+                    norm_num_groups=resnet_groups,
                 )
             )
 
@@ -778,7 +884,14 @@ class UpBlock3D(nn.Module):
         self.gradient_checkpointing = False
         self.resolution_idx = resolution_idx
 
-    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
+        num_frames: int = 1,
+    ) -> torch.FloatTensor:
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -827,12 +940,12 @@ class DownBlockMotion(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        output_scale_factor=1.0,
-        add_downsample=True,
-        downsample_padding=1,
-        temporal_num_attention_heads=1,
-        temporal_cross_attention_dim=None,
-        temporal_max_seq_length=32,
+        output_scale_factor: float = 1.0,
+        add_downsample: bool = True,
+        downsample_padding: int = 1,
+        temporal_num_attention_heads: int = 1,
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_max_seq_length: int = 32,
     ):
         super().__init__()
         resnets = []
@@ -875,7 +988,11 @@ class DownBlockMotion(nn.Module):
             self.downsamplers = nn.ModuleList(
                 [
                     Downsample2D(
-                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                        out_channels,
+                        use_conv=True,
+                        out_channels=out_channels,
+                        padding=downsample_padding,
+                        name="op",
                     )
                 ]
             )
@@ -884,7 +1001,13 @@ class DownBlockMotion(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, hidden_states, temb=None, scale: float = 1.0, num_frames=1):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+        num_frames: int = 1,
+    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
         output_states = ()
 
         blocks = zip(self.resnets, self.motion_modules)
@@ -899,14 +1022,20 @@ class DownBlockMotion(nn.Module):
 
                 if is_torch_version(">=", "1.11.0"):
                     hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                        create_custom_forward(resnet),
+                        hidden_states,
+                        temb,
+                        use_reentrant=False,
                     )
                 else:
                     hidden_states = torch.utils.checkpoint.checkpoint(
                         create_custom_forward(resnet), hidden_states, temb, scale
                     )
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, num_frames
+                    create_custom_forward(motion_module),
+                    hidden_states.requires_grad_(),
+                    temb,
+                    num_frames,
                 )
 
             else:
@@ -938,19 +1067,19 @@ class CrossAttnDownBlockMotion(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads=1,
-        cross_attention_dim=1280,
-        output_scale_factor=1.0,
-        downsample_padding=1,
-        add_downsample=True,
-        dual_cross_attention=False,
-        use_linear_projection=False,
-        only_cross_attention=False,
-        upcast_attention=False,
-        attention_type="default",
-        temporal_cross_attention_dim=None,
-        temporal_num_attention_heads=8,
-        temporal_max_seq_length=32,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        downsample_padding: int = 1,
+        add_downsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        attention_type: str = "default",
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_num_attention_heads: int = 8,
+        temporal_max_seq_length: int = 32,
     ):
         super().__init__()
         resnets = []
@@ -1026,7 +1155,11 @@ class CrossAttnDownBlockMotion(nn.Module):
             self.downsamplers = nn.ModuleList(
                 [
                     Downsample2D(
-                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                        out_channels,
+                        use_conv=True,
+                        out_channels=out_channels,
+                        padding=downsample_padding,
+                        name="op",
                     )
                 ]
             )
@@ -1037,14 +1170,14 @@ class CrossAttnDownBlockMotion(nn.Module):
 
     def forward(
         self,
-        hidden_states,
-        temb=None,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        num_frames=1,
-        encoder_attention_mask=None,
-        cross_attention_kwargs=None,
-        additional_residuals=None,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        additional_residuals: Optional[torch.FloatTensor] = None,
     ):
         output_states = ()
 
@@ -1115,7 +1248,7 @@ class CrossAttnUpBlockMotion(nn.Module):
         out_channels: int,
         prev_output_channel: int,
         temb_channels: int,
-        resolution_idx: int = None,
+        resolution_idx: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
         transformer_layers_per_block: int = 1,
@@ -1124,18 +1257,18 @@ class CrossAttnUpBlockMotion(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads=1,
-        cross_attention_dim=1280,
-        output_scale_factor=1.0,
-        add_upsample=True,
-        dual_cross_attention=False,
-        use_linear_projection=False,
-        only_cross_attention=False,
-        upcast_attention=False,
-        attention_type="default",
-        temporal_cross_attention_dim=None,
-        temporal_num_attention_heads=8,
-        temporal_max_seq_length=32,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        attention_type: str = "default",
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_num_attention_heads: int = 8,
+        temporal_max_seq_length: int = 32,
     ):
         super().__init__()
         resnets = []
@@ -1226,8 +1359,8 @@ class CrossAttnUpBlockMotion(nn.Module):
         upsample_size: Optional[int] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        num_frames=1,
-    ):
+        num_frames: int = 1,
+    ) -> torch.FloatTensor:
         lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
         is_freeu_enabled = (
             getattr(self, "s1", None)
@@ -1311,7 +1444,7 @@ class UpBlockMotion(nn.Module):
         prev_output_channel: int,
         out_channels: int,
         temb_channels: int,
-        resolution_idx: int = None,
+        resolution_idx: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
         resnet_eps: float = 1e-6,
@@ -1319,12 +1452,12 @@ class UpBlockMotion(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        output_scale_factor=1.0,
-        add_upsample=True,
-        temporal_norm_num_groups=32,
-        temporal_cross_attention_dim=None,
-        temporal_num_attention_heads=8,
-        temporal_max_seq_length=32,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        temporal_norm_num_groups: int = 32,
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_num_attention_heads: int = 8,
+        temporal_max_seq_length: int = 32,
     ):
         super().__init__()
         resnets = []
@@ -1375,8 +1508,14 @@ class UpBlockMotion(nn.Module):
         self.resolution_idx = resolution_idx
 
     def forward(
-        self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0, num_frames=1
-    ):
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        upsample_size=None,
+        scale: float = 1.0,
+        num_frames: int = 1,
+    ) -> torch.FloatTensor:
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -1415,7 +1554,10 @@ class UpBlockMotion(nn.Module):
 
                 if is_torch_version(">=", "1.11.0"):
                     hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                        create_custom_forward(resnet),
+                        hidden_states,
+                        temb,
+                        use_reentrant=False,
                     )
                 else:
                     hidden_states = torch.utils.checkpoint.checkpoint(
@@ -1451,16 +1593,16 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads=1,
-        output_scale_factor=1.0,
-        cross_attention_dim=1280,
-        dual_cross_attention=False,
-        use_linear_projection=False,
-        upcast_attention=False,
-        attention_type="default",
-        temporal_num_attention_heads=1,
-        temporal_cross_attention_dim=None,
-        temporal_max_seq_length=32,
+        num_attention_heads: int = 1,
+        output_scale_factor: float = 1.0,
+        cross_attention_dim: int = 1280,
+        dual_cross_attention: float = False,
+        use_linear_projection: float = False,
+        upcast_attention: float = False,
+        attention_type: str = "default",
+        temporal_num_attention_heads: int = 1,
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_max_seq_length: int = 32,
     ):
         super().__init__()
 
@@ -1554,7 +1696,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        num_frames=1,
+        num_frames: int = 1,
     ) -> torch.FloatTensor:
         lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
         hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
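Everything past this point is newly appended code: temporal-decoder blocks for the SVD VAE plus the spatio-temporal UNet blocks. They fold frames into the batch dimension and take an `image_only_indicator` of shape (batch, num_frames) that flags which batch entries are still images. A minimal sketch of exercising the new mid block directly; all shapes and sizes here are illustrative, not values from any shipped config:

import torch
from diffusers.models.unet_3d_blocks import UNetMidBlockSpatioTemporal

block = UNetMidBlockSpatioTemporal(
    in_channels=64,
    temb_channels=128,
    num_attention_heads=2,
    cross_attention_dim=32,
)

batch, num_frames = 1, 2
sample = torch.randn(batch * num_frames, 64, 8, 8)     # frames folded into batch
temb = torch.randn(batch * num_frames, 128)
encoder_hidden_states = torch.randn(batch * num_frames, 1, 32)
image_only_indicator = torch.zeros(batch, num_frames)  # 0 = video frame
out = block(sample, temb, encoder_hidden_states, image_only_indicator=image_only_indicator)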
@@ -1609,3 +1751,645 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
1609
1751
  hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1610
1752
 
1611
1753
  return hidden_states
1754
+
1755
+
1756
+ class MidBlockTemporalDecoder(nn.Module):
1757
+ def __init__(
1758
+ self,
1759
+ in_channels: int,
1760
+ out_channels: int,
1761
+ attention_head_dim: int = 512,
1762
+ num_layers: int = 1,
1763
+ upcast_attention: bool = False,
1764
+ ):
1765
+ super().__init__()
1766
+
1767
+ resnets = []
1768
+ attentions = []
1769
+ for i in range(num_layers):
1770
+ input_channels = in_channels if i == 0 else out_channels
1771
+ resnets.append(
1772
+ SpatioTemporalResBlock(
1773
+ in_channels=input_channels,
1774
+ out_channels=out_channels,
1775
+ temb_channels=None,
1776
+ eps=1e-6,
1777
+ temporal_eps=1e-5,
1778
+ merge_factor=0.0,
1779
+ merge_strategy="learned",
1780
+ switch_spatial_to_temporal_mix=True,
1781
+ )
1782
+ )
1783
+
1784
+ attentions.append(
1785
+ Attention(
1786
+ query_dim=in_channels,
1787
+ heads=in_channels // attention_head_dim,
1788
+ dim_head=attention_head_dim,
1789
+ eps=1e-6,
1790
+ upcast_attention=upcast_attention,
1791
+ norm_num_groups=32,
1792
+ bias=True,
1793
+ residual_connection=True,
1794
+ )
1795
+ )
1796
+
1797
+ self.attentions = nn.ModuleList(attentions)
1798
+ self.resnets = nn.ModuleList(resnets)
1799
+
1800
+ def forward(
1801
+ self,
1802
+ hidden_states: torch.FloatTensor,
1803
+ image_only_indicator: torch.FloatTensor,
1804
+ ):
1805
+ hidden_states = self.resnets[0](
1806
+ hidden_states,
1807
+ image_only_indicator=image_only_indicator,
1808
+ )
1809
+ for resnet, attn in zip(self.resnets[1:], self.attentions):
1810
+ hidden_states = attn(hidden_states)
1811
+ hidden_states = resnet(
1812
+ hidden_states,
1813
+ image_only_indicator=image_only_indicator,
1814
+ )
1815
+
1816
+ return hidden_states
1817
+
1818
+
1819
+ class UpBlockTemporalDecoder(nn.Module):
1820
+ def __init__(
1821
+ self,
1822
+ in_channels: int,
1823
+ out_channels: int,
1824
+ num_layers: int = 1,
1825
+ add_upsample: bool = True,
1826
+ ):
1827
+ super().__init__()
1828
+ resnets = []
1829
+ for i in range(num_layers):
1830
+ input_channels = in_channels if i == 0 else out_channels
1831
+
1832
+ resnets.append(
1833
+ SpatioTemporalResBlock(
1834
+ in_channels=input_channels,
1835
+ out_channels=out_channels,
1836
+ temb_channels=None,
1837
+ eps=1e-6,
1838
+ temporal_eps=1e-5,
1839
+ merge_factor=0.0,
1840
+ merge_strategy="learned",
1841
+ switch_spatial_to_temporal_mix=True,
1842
+ )
1843
+ )
1844
+ self.resnets = nn.ModuleList(resnets)
1845
+
1846
+ if add_upsample:
1847
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1848
+ else:
1849
+ self.upsamplers = None
1850
+
1851
+ def forward(
1852
+ self,
1853
+ hidden_states: torch.FloatTensor,
1854
+ image_only_indicator: torch.FloatTensor,
1855
+ ) -> torch.FloatTensor:
1856
+ for resnet in self.resnets:
1857
+ hidden_states = resnet(
1858
+ hidden_states,
1859
+ image_only_indicator=image_only_indicator,
1860
+ )
1861
+
1862
+ if self.upsamplers is not None:
1863
+ for upsampler in self.upsamplers:
1864
+ hidden_states = upsampler(hidden_states)
1865
+
1866
+ return hidden_states
1867
+
1868
+
1869
+ class UNetMidBlockSpatioTemporal(nn.Module):
1870
+ def __init__(
1871
+ self,
1872
+ in_channels: int,
1873
+ temb_channels: int,
1874
+ num_layers: int = 1,
1875
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
1876
+ num_attention_heads: int = 1,
1877
+ cross_attention_dim: int = 1280,
1878
+ ):
1879
+ super().__init__()
1880
+
1881
+ self.has_cross_attention = True
1882
+ self.num_attention_heads = num_attention_heads
1883
+
1884
+ # support for variable transformer layers per block
1885
+ if isinstance(transformer_layers_per_block, int):
1886
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
1887
+
1888
+ # there is always at least one resnet
1889
+ resnets = [
1890
+ SpatioTemporalResBlock(
1891
+ in_channels=in_channels,
1892
+ out_channels=in_channels,
1893
+ temb_channels=temb_channels,
1894
+ eps=1e-5,
1895
+ )
1896
+ ]
1897
+ attentions = []
1898
+
1899
+ for i in range(num_layers):
1900
+ attentions.append(
1901
+ TransformerSpatioTemporalModel(
1902
+ num_attention_heads,
1903
+ in_channels // num_attention_heads,
1904
+ in_channels=in_channels,
1905
+ num_layers=transformer_layers_per_block[i],
1906
+ cross_attention_dim=cross_attention_dim,
1907
+ )
1908
+ )
1909
+
1910
+ resnets.append(
1911
+ SpatioTemporalResBlock(
1912
+ in_channels=in_channels,
1913
+ out_channels=in_channels,
1914
+ temb_channels=temb_channels,
1915
+ eps=1e-5,
1916
+ )
1917
+ )
1918
+
1919
+ self.attentions = nn.ModuleList(attentions)
1920
+ self.resnets = nn.ModuleList(resnets)
1921
+
1922
+ self.gradient_checkpointing = False
1923
+
1924
+ def forward(
1925
+ self,
1926
+ hidden_states: torch.FloatTensor,
1927
+ temb: Optional[torch.FloatTensor] = None,
1928
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1929
+ image_only_indicator: Optional[torch.Tensor] = None,
1930
+ ) -> torch.FloatTensor:
1931
+ hidden_states = self.resnets[0](
1932
+ hidden_states,
1933
+ temb,
1934
+ image_only_indicator=image_only_indicator,
1935
+ )
1936
+
1937
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
1938
+ if self.training and self.gradient_checkpointing: # TODO
1939
+
1940
+ def create_custom_forward(module, return_dict=None):
1941
+ def custom_forward(*inputs):
1942
+ if return_dict is not None:
1943
+ return module(*inputs, return_dict=return_dict)
1944
+ else:
1945
+ return module(*inputs)
1946
+
1947
+ return custom_forward
1948
+
1949
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1950
+ hidden_states = attn(
1951
+ hidden_states,
1952
+ encoder_hidden_states=encoder_hidden_states,
1953
+ image_only_indicator=image_only_indicator,
1954
+ return_dict=False,
1955
+ )[0]
1956
+ hidden_states = torch.utils.checkpoint.checkpoint(
1957
+ create_custom_forward(resnet),
1958
+ hidden_states,
1959
+ temb,
1960
+ image_only_indicator,
1961
+ **ckpt_kwargs,
1962
+ )
1963
+ else:
1964
+ hidden_states = attn(
1965
+ hidden_states,
1966
+ encoder_hidden_states=encoder_hidden_states,
1967
+ image_only_indicator=image_only_indicator,
1968
+ return_dict=False,
1969
+ )[0]
1970
+ hidden_states = resnet(
1971
+ hidden_states,
1972
+ temb,
1973
+ image_only_indicator=image_only_indicator,
1974
+ )
1975
+
1976
+ return hidden_states
1977
+
1978
+
1979
+ class DownBlockSpatioTemporal(nn.Module):
1980
+ def __init__(
1981
+ self,
1982
+ in_channels: int,
1983
+ out_channels: int,
1984
+ temb_channels: int,
1985
+ num_layers: int = 1,
1986
+ add_downsample: bool = True,
1987
+ ):
1988
+ super().__init__()
1989
+ resnets = []
1990
+
1991
+ for i in range(num_layers):
1992
+ in_channels = in_channels if i == 0 else out_channels
1993
+ resnets.append(
1994
+ SpatioTemporalResBlock(
1995
+ in_channels=in_channels,
1996
+ out_channels=out_channels,
1997
+ temb_channels=temb_channels,
1998
+ eps=1e-5,
1999
+ )
2000
+ )
2001
+
2002
+ self.resnets = nn.ModuleList(resnets)
2003
+
2004
+ if add_downsample:
2005
+ self.downsamplers = nn.ModuleList(
2006
+ [
2007
+ Downsample2D(
2008
+ out_channels,
2009
+ use_conv=True,
2010
+ out_channels=out_channels,
2011
+ name="op",
2012
+ )
2013
+ ]
2014
+ )
2015
+ else:
2016
+ self.downsamplers = None
2017
+
2018
+ self.gradient_checkpointing = False
2019
+
2020
+ def forward(
2021
+ self,
2022
+ hidden_states: torch.FloatTensor,
2023
+ temb: Optional[torch.FloatTensor] = None,
2024
+ image_only_indicator: Optional[torch.Tensor] = None,
2025
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
2026
+ output_states = ()
2027
+ for resnet in self.resnets:
2028
+ if self.training and self.gradient_checkpointing:
2029
+
2030
+ def create_custom_forward(module):
2031
+ def custom_forward(*inputs):
2032
+ return module(*inputs)
2033
+
2034
+ return custom_forward
2035
+
2036
+ if is_torch_version(">=", "1.11.0"):
2037
+ hidden_states = torch.utils.checkpoint.checkpoint(
2038
+ create_custom_forward(resnet),
2039
+ hidden_states,
2040
+ temb,
2041
+ image_only_indicator,
2042
+ use_reentrant=False,
2043
+ )
2044
+ else:
2045
+ hidden_states = torch.utils.checkpoint.checkpoint(
2046
+ create_custom_forward(resnet),
2047
+ hidden_states,
2048
+ temb,
2049
+ image_only_indicator,
2050
+ )
2051
+ else:
2052
+ hidden_states = resnet(
2053
+ hidden_states,
2054
+ temb,
2055
+ image_only_indicator=image_only_indicator,
2056
+ )
2057
+
2058
+ output_states = output_states + (hidden_states,)
2059
+
2060
+ if self.downsamplers is not None:
2061
+ for downsampler in self.downsamplers:
2062
+ hidden_states = downsampler(hidden_states)
2063
+
2064
+ output_states = output_states + (hidden_states,)
2065
+
2066
+ return hidden_states, output_states
2067
+
+
+ class CrossAttnDownBlockSpatioTemporal(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         temb_channels: int,
+         num_layers: int = 1,
+         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+         num_attention_heads: int = 1,
+         cross_attention_dim: int = 1280,
+         add_downsample: bool = True,
+     ):
+         super().__init__()
+         resnets = []
+         attentions = []
+
+         self.has_cross_attention = True
+         self.num_attention_heads = num_attention_heads
+         if isinstance(transformer_layers_per_block, int):
+             transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+         for i in range(num_layers):
+             in_channels = in_channels if i == 0 else out_channels
+             resnets.append(
+                 SpatioTemporalResBlock(
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     eps=1e-6,
+                 )
+             )
+             attentions.append(
+                 TransformerSpatioTemporalModel(
+                     num_attention_heads,
+                     out_channels // num_attention_heads,
+                     in_channels=out_channels,
+                     num_layers=transformer_layers_per_block[i],
+                     cross_attention_dim=cross_attention_dim,
+                 )
+             )
+
+         self.attentions = nn.ModuleList(attentions)
+         self.resnets = nn.ModuleList(resnets)
+
+         if add_downsample:
+             self.downsamplers = nn.ModuleList(
+                 [
+                     Downsample2D(
+                         out_channels,
+                         use_conv=True,
+                         out_channels=out_channels,
+                         padding=1,
+                         name="op",
+                     )
+                 ]
+             )
+         else:
+             self.downsamplers = None
+
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         temb: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         image_only_indicator: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+         output_states = ()
+
+         blocks = list(zip(self.resnets, self.attentions))
+         for resnet, attn in blocks:
+             if self.training and self.gradient_checkpointing:  # TODO
+
+                 def create_custom_forward(module, return_dict=None):
+                     def custom_forward(*inputs):
+                         if return_dict is not None:
+                             return module(*inputs, return_dict=return_dict)
+                         else:
+                             return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(resnet),
+                     hidden_states,
+                     temb,
+                     image_only_indicator,
+                     **ckpt_kwargs,
+                 )
+
+                 hidden_states = attn(
+                     hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     image_only_indicator=image_only_indicator,
+                     return_dict=False,
+                 )[0]
+             else:
+                 hidden_states = resnet(
+                     hidden_states,
+                     temb,
+                     image_only_indicator=image_only_indicator,
+                 )
+                 hidden_states = attn(
+                     hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     image_only_indicator=image_only_indicator,
+                     return_dict=False,
+                 )[0]
+
+             output_states = output_states + (hidden_states,)
+
+         if self.downsamplers is not None:
+             for downsampler in self.downsamplers:
+                 hidden_states = downsampler(hidden_states)
+
+             output_states = output_states + (hidden_states,)
+
+         return hidden_states, output_states
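A hedged usage sketch for `CrossAttnDownBlockSpatioTemporal`: the channel and head counts below merely mirror a typical first SVD down stage, and the input shapes (video frames folded into the batch dimension, one time embedding and one image-conditioning token per frame, an all-zeros `image_only_indicator` marking every sample as video) are assumptions drawn from how `UNetSpatioTemporalConditionModel` drives these blocks, not an official API example:

import torch
from diffusers.models.unet_3d_blocks import CrossAttnDownBlockSpatioTemporal

block = CrossAttnDownBlockSpatioTemporal(
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    num_layers=2,
    num_attention_heads=5,   # 320 // 5 = 64-dim attention heads
    cross_attention_dim=1024,
)

batch, frames = 1, 2
sample = torch.randn(batch * frames, 320, 32, 32)             # (batch * frames, C, H, W)
temb = torch.randn(batch * frames, 1280)                      # per-frame time embedding
encoder_hidden_states = torch.randn(batch * frames, 1, 1024)  # per-frame image embedding
image_only_indicator = torch.zeros(batch, frames)             # 0 = treat sample as video

sample, skips = block(
    sample,
    temb=temb,
    encoder_hidden_states=encoder_hidden_states,
    image_only_indicator=image_only_indicator,
)
# `skips` holds one residual per resnet/attention pair plus one from the
# downsampler; with add_downsample=True (the default) `sample` comes back at 16x16.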
+
+
+ class UpBlockSpatioTemporal(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         prev_output_channel: int,
+         out_channels: int,
+         temb_channels: int,
+         resolution_idx: Optional[int] = None,
+         num_layers: int = 1,
+         resnet_eps: float = 1e-6,
+         add_upsample: bool = True,
+     ):
+         super().__init__()
+         resnets = []
+
+         for i in range(num_layers):
+             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+             resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+             resnets.append(
+                 SpatioTemporalResBlock(
+                     in_channels=resnet_in_channels + res_skip_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     eps=resnet_eps,
+                 )
+             )
+
+         self.resnets = nn.ModuleList(resnets)
+
+         if add_upsample:
+             self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+         else:
+             self.upsamplers = None
+
+         self.gradient_checkpointing = False
+         self.resolution_idx = resolution_idx
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+         temb: Optional[torch.FloatTensor] = None,
+         image_only_indicator: Optional[torch.Tensor] = None,
+     ) -> torch.FloatTensor:
+         for resnet in self.resnets:
+             # pop res hidden states
+             res_hidden_states = res_hidden_states_tuple[-1]
+             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+             if self.training and self.gradient_checkpointing:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         return module(*inputs)
+
+                     return custom_forward
+
+                 if is_torch_version(">=", "1.11.0"):
+                     hidden_states = torch.utils.checkpoint.checkpoint(
+                         create_custom_forward(resnet),
+                         hidden_states,
+                         temb,
+                         image_only_indicator,
+                         use_reentrant=False,
+                     )
+                 else:
+                     hidden_states = torch.utils.checkpoint.checkpoint(
+                         create_custom_forward(resnet),
+                         hidden_states,
+                         temb,
+                         image_only_indicator,
+                     )
+             else:
+                 hidden_states = resnet(
+                     hidden_states,
+                     temb,
+                     image_only_indicator=image_only_indicator,
+                 )
+
+         if self.upsamplers is not None:
+             for upsampler in self.upsamplers:
+                 hidden_states = upsampler(hidden_states)
+
+         return hidden_states
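One detail worth calling out in `UpBlockSpatioTemporal.forward`: `res_hidden_states_tuple` is consumed from the end, so residuals recorded on the way down are concatenated in reverse order (deepest skip first). A tiny standalone sketch of just that bookkeeping, with placeholder strings instead of tensors:

skips = ("d0", "d1", "d2")  # residuals appended top-down by the down blocks
while skips:
    res, skips = skips[-1], skips[:-1]
    print(res)  # prints d2, d1, d0 -- the deepest residual is used first

The channel arithmetic in `__init__` matches this order: the last resnet (`i == num_layers - 1`) receives the shallowest skip, which is why it uses `in_channels` rather than `out_channels` as `res_skip_channels`.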
+
+
+ class CrossAttnUpBlockSpatioTemporal(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         prev_output_channel: int,
+         temb_channels: int,
+         resolution_idx: Optional[int] = None,
+         num_layers: int = 1,
+         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+         resnet_eps: float = 1e-6,
+         num_attention_heads: int = 1,
+         cross_attention_dim: int = 1280,
+         add_upsample: bool = True,
+     ):
+         super().__init__()
+         resnets = []
+         attentions = []
+
+         self.has_cross_attention = True
+         self.num_attention_heads = num_attention_heads
+
+         if isinstance(transformer_layers_per_block, int):
+             transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+         for i in range(num_layers):
+             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+             resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+             resnets.append(
+                 SpatioTemporalResBlock(
+                     in_channels=resnet_in_channels + res_skip_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     eps=resnet_eps,
+                 )
+             )
+             attentions.append(
+                 TransformerSpatioTemporalModel(
+                     num_attention_heads,
+                     out_channels // num_attention_heads,
+                     in_channels=out_channels,
+                     num_layers=transformer_layers_per_block[i],
+                     cross_attention_dim=cross_attention_dim,
+                 )
+             )
+
+         self.attentions = nn.ModuleList(attentions)
+         self.resnets = nn.ModuleList(resnets)
+
+         if add_upsample:
+             self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+         else:
+             self.upsamplers = None
+
+         self.gradient_checkpointing = False
+         self.resolution_idx = resolution_idx
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+         temb: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         image_only_indicator: Optional[torch.Tensor] = None,
+     ) -> torch.FloatTensor:
+         for resnet, attn in zip(self.resnets, self.attentions):
+             # pop res hidden states
+             res_hidden_states = res_hidden_states_tuple[-1]
+             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+             if self.training and self.gradient_checkpointing:  # TODO
+
+                 def create_custom_forward(module, return_dict=None):
+                     def custom_forward(*inputs):
+                         if return_dict is not None:
+                             return module(*inputs, return_dict=return_dict)
+                         else:
+                             return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(resnet),
+                     hidden_states,
+                     temb,
+                     image_only_indicator,
+                     **ckpt_kwargs,
+                 )
+                 hidden_states = attn(
+                     hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     image_only_indicator=image_only_indicator,
+                     return_dict=False,
+                 )[0]
+             else:
+                 hidden_states = resnet(
+                     hidden_states,
+                     temb,
+                     image_only_indicator=image_only_indicator,
+                 )
+                 hidden_states = attn(
+                     hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     image_only_indicator=image_only_indicator,
+                     return_dict=False,
+                 )[0]
+
+         if self.upsamplers is not None:
+             for upsampler in self.upsamplers:
+                 hidden_states = upsampler(hidden_states)
+
+         return hidden_states
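A hedged usage sketch for `CrossAttnUpBlockSpatioTemporal`; every concrete number below is illustrative rather than an official configuration, and the skip-tensor channel counts follow directly from the `res_skip_channels` logic above (the tuple is popped from the end, so its first element must carry `in_channels`):

import torch
from diffusers.models.unet_3d_blocks import CrossAttnUpBlockSpatioTemporal

block = CrossAttnUpBlockSpatioTemporal(
    in_channels=640,
    out_channels=1280,
    prev_output_channel=1280,
    temb_channels=1280,
    num_layers=3,
    num_attention_heads=20,  # 1280 // 20 = 64-dim attention heads
    cross_attention_dim=1024,
)

batch, frames, h, w = 1, 2, 8, 8
sample = torch.randn(batch * frames, 1280, h, w)
skips = (
    torch.randn(batch * frames, 640, h, w),   # popped last (i == num_layers - 1)
    torch.randn(batch * frames, 1280, h, w),
    torch.randn(batch * frames, 1280, h, w),  # popped first
)
temb = torch.randn(batch * frames, 1280)
encoder_hidden_states = torch.randn(batch * frames, 1, 1024)
image_only_indicator = torch.zeros(batch, frames)

out = block(
    sample,
    res_hidden_states_tuple=skips,
    temb=temb,
    encoder_hidden_states=encoder_hidden_states,
    image_only_indicator=image_only_indicator,
)
# with add_upsample=True (the default) `out` is (batch * frames, 1280, 16, 16)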