diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2252 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +3 -14
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +293 -8
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1937 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +403 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +50 -6
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +37 -15
  210. diffusers/utils/loading_utils.py +80 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
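Highlights from the list above: the monolithic `diffusers/loaders/lora.py` (-1728 lines) is split into `lora_base.py` and `lora_pipeline.py`, `loaders/autoencoder.py` and `loaders/controlnet.py` are removed, and a wave of new model families lands (Flux, CogVideoX, AuraFlow, Kolors, Latte, Lumina, Stable Audio, SparseCtrl, plus the PAG pipeline family). A quick smoke test of a few of the new names (a minimal sketch; it assumes these pipeline classes are re-exported from the package root, as diffusers conventionally does for new pipelines):

    import diffusers

    print(diffusers.__version__)  # expected: "0.30.1"

    # Class names taken from the new pipeline files listed above;
    # their availability at the package root is an assumption of this sketch.
    from diffusers import CogVideoXPipeline, FluxPipeline, StableAudioPipeline  # noqa: F401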
@@ -27,17 +27,58 @@ from ..resnet import (
     TemporalConvLayer,
     Upsample2D,
 )
-from ..transformers.dual_transformer_2d import DualTransformer2DModel
 from ..transformers.transformer_2d import Transformer2DModel
 from ..transformers.transformer_temporal import (
     TransformerSpatioTemporalModel,
     TransformerTemporalModel,
 )
+from .unet_motion_model import (
+    CrossAttnDownBlockMotion,
+    CrossAttnUpBlockMotion,
+    DownBlockMotion,
+    UNetMidBlockCrossAttnMotion,
+    UpBlockMotion,
+)


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+class DownBlockMotion(DownBlockMotion):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `DownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import DownBlockMotion` instead."
+        deprecate("DownBlockMotion", "1.0.0", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
+class CrossAttnDownBlockMotion(CrossAttnDownBlockMotion):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `CrossAttnDownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnDownBlockMotion` instead."
+        deprecate("CrossAttnDownBlockMotion", "1.0.0", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
+class UpBlockMotion(UpBlockMotion):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `UpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UpBlockMotion` instead."
+        deprecate("UpBlockMotion", "1.0.0", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
+class CrossAttnUpBlockMotion(CrossAttnUpBlockMotion):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `CrossAttnUpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnUpBlockMotion` instead."
+        deprecate("CrossAttnUpBlockMotion", "1.0.0", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
+class UNetMidBlockCrossAttnMotion(UNetMidBlockCrossAttnMotion):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `UNetMidBlockCrossAttnMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UNetMidBlockCrossAttnMotion` instead."
+        deprecate("UNetMidBlockCrossAttnMotion", "1.0.0", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
 def get_down_block(
     down_block_type: str,
     num_layers: int,
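The five shims above keep the old import location alive: each subclass shadows the name just imported from `unet_motion_model`, warns once at construction time via `deprecate(..., "1.0.0", ...)`, and otherwise behaves exactly like the relocated class. What downstream code observes, as a minimal sketch (the constructor arguments are illustrative, not a verified minimal set):

    import warnings

    # Old path: resolves to the deprecation shim defined in the hunk above.
    from diffusers.models.unets.unet_3d_blocks import DownBlockMotion as OldDownBlockMotion
    # New path: the relocated implementation, imported silently.
    from diffusers.models.unets.unet_motion_model import DownBlockMotion

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        OldDownBlockMotion(in_channels=32, out_channels=32, temb_channels=128)
    assert any("deprecated" in str(w.message) for w in caught)

    DownBlockMotion(in_channels=32, out_channels=32, temb_channels=128)  # no warning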
@@ -58,12 +99,12 @@ def get_down_block(
     resnet_time_scale_shift: str = "default",
     temporal_num_attention_heads: int = 8,
     temporal_max_seq_length: int = 32,
-    transformer_layers_per_block: int = 1,
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    dropout: float = 0.0,
 ) -> Union[
     "DownBlock3D",
     "CrossAttnDownBlock3D",
-    "DownBlockMotion",
-    "CrossAttnDownBlockMotion",
     "DownBlockSpatioTemporal",
     "CrossAttnDownBlockSpatioTemporal",
 ]:
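Both `get_down_block` and (below) `get_up_block` now accept `transformer_layers_per_block` and the new `temporal_transformer_layers_per_block` as either an int or a per-layer tuple. The usual way such a `Union[int, Tuple[int]]` knob is consumed is to broadcast the scalar; a self-contained sketch of that normalization (the helper name is hypothetical, not part of diffusers):

    from typing import Tuple, Union

    def to_per_layer(value: Union[int, Tuple[int, ...]], num_layers: int) -> Tuple[int, ...]:
        # Broadcast a scalar to one entry per layer; validate explicit tuples.
        if isinstance(value, int):
            return (value,) * num_layers
        if len(value) != num_layers:
            raise ValueError(f"expected {num_layers} entries, got {len(value)}")
        return tuple(value)

    assert to_per_layer(2, 3) == (2, 2, 2)
    assert to_per_layer((1, 2, 4), 3) == (1, 2, 4)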
@@ -79,6 +120,7 @@ def get_down_block(
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            dropout=dropout,
         )
     elif down_block_type == "CrossAttnDownBlock3D":
         if cross_attention_dim is None:
@@ -100,45 +142,7 @@
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
-        )
-    if down_block_type == "DownBlockMotion":
-        return DownBlockMotion(
-            num_layers=num_layers,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            temb_channels=temb_channels,
-            add_downsample=add_downsample,
-            resnet_eps=resnet_eps,
-            resnet_act_fn=resnet_act_fn,
-            resnet_groups=resnet_groups,
-            downsample_padding=downsample_padding,
-            resnet_time_scale_shift=resnet_time_scale_shift,
-            temporal_num_attention_heads=temporal_num_attention_heads,
-            temporal_max_seq_length=temporal_max_seq_length,
-        )
-    elif down_block_type == "CrossAttnDownBlockMotion":
-        if cross_attention_dim is None:
-            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
-        return CrossAttnDownBlockMotion(
-            num_layers=num_layers,
-            transformer_layers_per_block=transformer_layers_per_block,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            temb_channels=temb_channels,
-            add_downsample=add_downsample,
-            resnet_eps=resnet_eps,
-            resnet_act_fn=resnet_act_fn,
-            resnet_groups=resnet_groups,
-            downsample_padding=downsample_padding,
-            cross_attention_dim=cross_attention_dim,
-            num_attention_heads=num_attention_heads,
-            dual_cross_attention=dual_cross_attention,
-            use_linear_projection=use_linear_projection,
-            only_cross_attention=only_cross_attention,
-            upcast_attention=upcast_attention,
-            resnet_time_scale_shift=resnet_time_scale_shift,
-            temporal_num_attention_heads=temporal_num_attention_heads,
-            temporal_max_seq_length=temporal_max_seq_length,
+            dropout=dropout,
         )
     elif down_block_type == "DownBlockSpatioTemporal":
         # added for SDV
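With the two Motion branches deleted, `unet_3d_blocks.get_down_block` dispatches only the 3D and spatio-temporal block types; motion blocks are now constructed through `unet_motion_model`. Requesting a Motion type here should fall through to the factory's unknown-type error (a sketch; the argument set and the exact error message are assumptions, not taken from the diff):

    from diffusers.models.unets.unet_3d_blocks import get_down_block

    try:
        get_down_block(
            "DownBlockMotion",
            num_layers=2,
            in_channels=32,
            out_channels=32,
            temb_channels=128,
            add_downsample=True,
            resnet_eps=1e-6,
            resnet_act_fn="swish",
            num_attention_heads=1,
        )
    except ValueError as err:
        print(err)  # e.g. "DownBlockMotion does not exist."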
@@ -189,13 +193,12 @@ def get_up_block(
     temporal_num_attention_heads: int = 8,
     temporal_cross_attention_dim: Optional[int] = None,
     temporal_max_seq_length: int = 32,
-    transformer_layers_per_block: int = 1,
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     dropout: float = 0.0,
 ) -> Union[
     "UpBlock3D",
     "CrossAttnUpBlock3D",
-    "UpBlockMotion",
-    "CrossAttnUpBlockMotion",
     "UpBlockSpatioTemporal",
     "CrossAttnUpBlockSpatioTemporal",
 ]:
@@ -212,6 +215,7 @@
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
             resolution_idx=resolution_idx,
+            dropout=dropout,
         )
     elif up_block_type == "CrossAttnUpBlock3D":
         if cross_attention_dim is None:
@@ -234,47 +238,7 @@
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
             resolution_idx=resolution_idx,
-        )
-    if up_block_type == "UpBlockMotion":
-        return UpBlockMotion(
-            num_layers=num_layers,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            prev_output_channel=prev_output_channel,
-            temb_channels=temb_channels,
-            add_upsample=add_upsample,
-            resnet_eps=resnet_eps,
-            resnet_act_fn=resnet_act_fn,
-            resnet_groups=resnet_groups,
-            resnet_time_scale_shift=resnet_time_scale_shift,
-            resolution_idx=resolution_idx,
-            temporal_num_attention_heads=temporal_num_attention_heads,
-            temporal_max_seq_length=temporal_max_seq_length,
-        )
-    elif up_block_type == "CrossAttnUpBlockMotion":
-        if cross_attention_dim is None:
-            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
-        return CrossAttnUpBlockMotion(
-            num_layers=num_layers,
-            transformer_layers_per_block=transformer_layers_per_block,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            prev_output_channel=prev_output_channel,
-            temb_channels=temb_channels,
-            add_upsample=add_upsample,
-            resnet_eps=resnet_eps,
-            resnet_act_fn=resnet_act_fn,
-            resnet_groups=resnet_groups,
-            cross_attention_dim=cross_attention_dim,
-            num_attention_heads=num_attention_heads,
-            dual_cross_attention=dual_cross_attention,
-            use_linear_projection=use_linear_projection,
-            only_cross_attention=only_cross_attention,
-            upcast_attention=upcast_attention,
-            resnet_time_scale_shift=resnet_time_scale_shift,
-            resolution_idx=resolution_idx,
-            temporal_num_attention_heads=temporal_num_attention_heads,
-            temporal_max_seq_length=temporal_max_seq_length,
+            dropout=dropout,
         )
     elif up_block_type == "UpBlockSpatioTemporal":
         # added for SDV
@@ -932,839 +896,6 @@ class UpBlock3D(nn.Module):
         return hidden_states


-class DownBlockMotion(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        temb_channels: int,
-        dropout: float = 0.0,
-        num_layers: int = 1,
-        resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",
-        resnet_act_fn: str = "swish",
-        resnet_groups: int = 32,
-        resnet_pre_norm: bool = True,
-        output_scale_factor: float = 1.0,
-        add_downsample: bool = True,
-        downsample_padding: int = 1,
-        temporal_num_attention_heads: int = 1,
-        temporal_cross_attention_dim: Optional[int] = None,
-        temporal_max_seq_length: int = 32,
-    ):
-        super().__init__()
-        resnets = []
-        motion_modules = []
-
-        for i in range(num_layers):
-            in_channels = in_channels if i == 0 else out_channels
-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=in_channels,
-                    out_channels=out_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
-                )
-            )
-            motion_modules.append(
-                TransformerTemporalModel(
-                    num_attention_heads=temporal_num_attention_heads,
-                    in_channels=out_channels,
-                    norm_num_groups=resnet_groups,
-                    cross_attention_dim=temporal_cross_attention_dim,
-                    attention_bias=False,
-                    activation_fn="geglu",
-                    positional_embeddings="sinusoidal",
-                    num_positional_embeddings=temporal_max_seq_length,
-                    attention_head_dim=out_channels // temporal_num_attention_heads,
-                )
-            )
-
-        self.resnets = nn.ModuleList(resnets)
-        self.motion_modules = nn.ModuleList(motion_modules)
-
-        if add_downsample:
-            self.downsamplers = nn.ModuleList(
-                [
-                    Downsample2D(
-                        out_channels,
-                        use_conv=True,
-                        out_channels=out_channels,
-                        padding=downsample_padding,
-                        name="op",
-                    )
-                ]
-            )
-        else:
-            self.downsamplers = None
-
-        self.gradient_checkpointing = False
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        temb: Optional[torch.Tensor] = None,
-        num_frames: int = 1,
-        *args,
-        **kwargs,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
-        output_states = ()
-
-        blocks = zip(self.resnets, self.motion_modules)
-        for resnet, motion_module in blocks:
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet),
-                        hidden_states,
-                        temb,
-                        use_reentrant=False,
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-
-            else:
-                hidden_states = resnet(hidden_states, temb)
-            hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
-
-            output_states = output_states + (hidden_states,)
-
-        if self.downsamplers is not None:
-            for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
-
-            output_states = output_states + (hidden_states,)
-
-        return hidden_states, output_states
-
-
-class CrossAttnDownBlockMotion(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        temb_channels: int,
-        dropout: float = 0.0,
-        num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
-        resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",
-        resnet_act_fn: str = "swish",
-        resnet_groups: int = 32,
-        resnet_pre_norm: bool = True,
-        num_attention_heads: int = 1,
-        cross_attention_dim: int = 1280,
-        output_scale_factor: float = 1.0,
-        downsample_padding: int = 1,
-        add_downsample: bool = True,
-        dual_cross_attention: bool = False,
-        use_linear_projection: bool = False,
-        only_cross_attention: bool = False,
-        upcast_attention: bool = False,
-        attention_type: str = "default",
-        temporal_cross_attention_dim: Optional[int] = None,
-        temporal_num_attention_heads: int = 8,
-        temporal_max_seq_length: int = 32,
-    ):
-        super().__init__()
-        resnets = []
-        attentions = []
-        motion_modules = []
-
-        self.has_cross_attention = True
-        self.num_attention_heads = num_attention_heads
-
-        for i in range(num_layers):
-            in_channels = in_channels if i == 0 else out_channels
-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=in_channels,
-                    out_channels=out_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
-                )
-            )
-
-            if not dual_cross_attention:
-                attentions.append(
-                    Transformer2DModel(
-                        num_attention_heads,
-                        out_channels // num_attention_heads,
-                        in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
-                        cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
-                        use_linear_projection=use_linear_projection,
-                        only_cross_attention=only_cross_attention,
-                        upcast_attention=upcast_attention,
-                        attention_type=attention_type,
-                    )
-                )
-            else:
-                attentions.append(
-                    DualTransformer2DModel(
-                        num_attention_heads,
-                        out_channels // num_attention_heads,
-                        in_channels=out_channels,
-                        num_layers=1,
-                        cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
-                    )
-                )
-
-            motion_modules.append(
-                TransformerTemporalModel(
-                    num_attention_heads=temporal_num_attention_heads,
-                    in_channels=out_channels,
-                    norm_num_groups=resnet_groups,
-                    cross_attention_dim=temporal_cross_attention_dim,
-                    attention_bias=False,
-                    activation_fn="geglu",
-                    positional_embeddings="sinusoidal",
-                    num_positional_embeddings=temporal_max_seq_length,
-                    attention_head_dim=out_channels // temporal_num_attention_heads,
-                )
-            )
-
-        self.attentions = nn.ModuleList(attentions)
-        self.resnets = nn.ModuleList(resnets)
-        self.motion_modules = nn.ModuleList(motion_modules)
-
-        if add_downsample:
-            self.downsamplers = nn.ModuleList(
-                [
-                    Downsample2D(
-                        out_channels,
-                        use_conv=True,
-                        out_channels=out_channels,
-                        padding=downsample_padding,
-                        name="op",
-                    )
-                ]
-            )
-        else:
-            self.downsamplers = None
-
-        self.gradient_checkpointing = False
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        temb: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        num_frames: int = 1,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        additional_residuals: Optional[torch.Tensor] = None,
-    ):
-        if cross_attention_kwargs is not None:
-            if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
-
-        output_states = ()
-
-        blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
-        for i, (resnet, attn, motion_module) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
-            else:
-                hidden_states = resnet(hidden_states, temb)
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
-            hidden_states = motion_module(
-                hidden_states,
-                num_frames=num_frames,
-            )[0]
-
-            # apply additional residuals to the output of the last pair of resnet and attention blocks
-            if i == len(blocks) - 1 and additional_residuals is not None:
-                hidden_states = hidden_states + additional_residuals
-
-            output_states = output_states + (hidden_states,)
-
-        if self.downsamplers is not None:
-            for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
-
-            output_states = output_states + (hidden_states,)
-
-        return hidden_states, output_states
-
-
-class CrossAttnUpBlockMotion(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        prev_output_channel: int,
-        temb_channels: int,
-        resolution_idx: Optional[int] = None,
-        dropout: float = 0.0,
-        num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
-        resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",
-        resnet_act_fn: str = "swish",
-        resnet_groups: int = 32,
-        resnet_pre_norm: bool = True,
-        num_attention_heads: int = 1,
-        cross_attention_dim: int = 1280,
-        output_scale_factor: float = 1.0,
-        add_upsample: bool = True,
-        dual_cross_attention: bool = False,
-        use_linear_projection: bool = False,
-        only_cross_attention: bool = False,
-        upcast_attention: bool = False,
-        attention_type: str = "default",
-        temporal_cross_attention_dim: Optional[int] = None,
-        temporal_num_attention_heads: int = 8,
-        temporal_max_seq_length: int = 32,
-    ):
-        super().__init__()
-        resnets = []
-        attentions = []
-        motion_modules = []
-
-        self.has_cross_attention = True
-        self.num_attention_heads = num_attention_heads
-
-        for i in range(num_layers):
-            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
-            resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=resnet_in_channels + res_skip_channels,
-                    out_channels=out_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
-                )
-            )
-
-            if not dual_cross_attention:
-                attentions.append(
-                    Transformer2DModel(
-                        num_attention_heads,
-                        out_channels // num_attention_heads,
-                        in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
-                        cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
-                        use_linear_projection=use_linear_projection,
-                        only_cross_attention=only_cross_attention,
-                        upcast_attention=upcast_attention,
-                        attention_type=attention_type,
-                    )
-                )
-            else:
-                attentions.append(
-                    DualTransformer2DModel(
-                        num_attention_heads,
-                        out_channels // num_attention_heads,
-                        in_channels=out_channels,
-                        num_layers=1,
-                        cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
-                    )
-                )
-            motion_modules.append(
-                TransformerTemporalModel(
-                    num_attention_heads=temporal_num_attention_heads,
-                    in_channels=out_channels,
-                    norm_num_groups=resnet_groups,
-                    cross_attention_dim=temporal_cross_attention_dim,
-                    attention_bias=False,
-                    activation_fn="geglu",
-                    positional_embeddings="sinusoidal",
-                    num_positional_embeddings=temporal_max_seq_length,
-                    attention_head_dim=out_channels // temporal_num_attention_heads,
-                )
-            )
-
-        self.attentions = nn.ModuleList(attentions)
-        self.resnets = nn.ModuleList(resnets)
-        self.motion_modules = nn.ModuleList(motion_modules)
-
-        if add_upsample:
-            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
-        else:
-            self.upsamplers = None
-
-        self.gradient_checkpointing = False
-        self.resolution_idx = resolution_idx
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
-        temb: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        num_frames: int = 1,
-    ) -> torch.Tensor:
-        if cross_attention_kwargs is not None:
-            if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
-
-        is_freeu_enabled = (
-            getattr(self, "s1", None)
-            and getattr(self, "s2", None)
-            and getattr(self, "b1", None)
-            and getattr(self, "b2", None)
-        )
-
-        blocks = zip(self.resnets, self.attentions, self.motion_modules)
-        for resnet, attn, motion_module in blocks:
-            # pop res hidden states
-            res_hidden_states = res_hidden_states_tuple[-1]
-            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
-
-            # FreeU: Only operate on the first two stages
-            if is_freeu_enabled:
-                hidden_states, res_hidden_states = apply_freeu(
-                    self.resolution_idx,
-                    hidden_states,
-                    res_hidden_states,
-                    s1=self.s1,
-                    s2=self.s2,
-                    b1=self.b1,
-                    b2=self.b2,
-                )
-
-            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
-            else:
-                hidden_states = resnet(hidden_states, temb)
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
-            hidden_states = motion_module(
-                hidden_states,
-                num_frames=num_frames,
-            )[0]
-
-        if self.upsamplers is not None:
-            for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
-
-        return hidden_states
-
-
-class UpBlockMotion(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        prev_output_channel: int,
-        out_channels: int,
-        temb_channels: int,
-        resolution_idx: Optional[int] = None,
-        dropout: float = 0.0,
-        num_layers: int = 1,
-        resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",
-        resnet_act_fn: str = "swish",
-        resnet_groups: int = 32,
-        resnet_pre_norm: bool = True,
-        output_scale_factor: float = 1.0,
-        add_upsample: bool = True,
-        temporal_norm_num_groups: int = 32,
-        temporal_cross_attention_dim: Optional[int] = None,
-        temporal_num_attention_heads: int = 8,
-        temporal_max_seq_length: int = 32,
-    ):
-        super().__init__()
-        resnets = []
-        motion_modules = []
-
-        for i in range(num_layers):
-            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
-            resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=resnet_in_channels + res_skip_channels,
-                    out_channels=out_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
-                )
-            )
-
-            motion_modules.append(
-                TransformerTemporalModel(
-                    num_attention_heads=temporal_num_attention_heads,
-                    in_channels=out_channels,
-                    norm_num_groups=temporal_norm_num_groups,
-                    cross_attention_dim=temporal_cross_attention_dim,
-                    attention_bias=False,
-                    activation_fn="geglu",
-                    positional_embeddings="sinusoidal",
-                    num_positional_embeddings=temporal_max_seq_length,
-                    attention_head_dim=out_channels // temporal_num_attention_heads,
-                )
-            )
-
-        self.resnets = nn.ModuleList(resnets)
-        self.motion_modules = nn.ModuleList(motion_modules)
-
-        if add_upsample:
-            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
-        else:
-            self.upsamplers = None
-
-        self.gradient_checkpointing = False
-        self.resolution_idx = resolution_idx
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
-        temb: Optional[torch.Tensor] = None,
-        upsample_size=None,
-        num_frames: int = 1,
-        *args,
-        **kwargs,
-    ) -> torch.Tensor:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
-        is_freeu_enabled = (
-            getattr(self, "s1", None)
-            and getattr(self, "s2", None)
-            and getattr(self, "b1", None)
-            and getattr(self, "b2", None)
-        )
-
-        blocks = zip(self.resnets, self.motion_modules)
-
-        for resnet, motion_module in blocks:
-            # pop res hidden states
-            res_hidden_states = res_hidden_states_tuple[-1]
-            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
-
-            # FreeU: Only operate on the first two stages
-            if is_freeu_enabled:
-                hidden_states, res_hidden_states = apply_freeu(
-                    self.resolution_idx,
-                    hidden_states,
-                    res_hidden_states,
-                    s1=self.s1,
-                    s2=self.s2,
-                    b1=self.b1,
-                    b2=self.b2,
-                )
-
-            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet),
-                        hidden_states,
-                        temb,
-                        use_reentrant=False,
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-
-            else:
-                hidden_states = resnet(hidden_states, temb)
-            hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
-
-        if self.upsamplers is not None:
-            for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
-
-        return hidden_states
-
-
-class UNetMidBlockCrossAttnMotion(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        temb_channels: int,
-        dropout: float = 0.0,
-        num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
-        resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",
-        resnet_act_fn: str = "swish",
-        resnet_groups: int = 32,
-        resnet_pre_norm: bool = True,
-        num_attention_heads: int = 1,
-        output_scale_factor: float = 1.0,
-        cross_attention_dim: int = 1280,
-        dual_cross_attention: float = False,
-        use_linear_projection: float = False,
-        upcast_attention: float = False,
-        attention_type: str = "default",
-        temporal_num_attention_heads: int = 1,
-        temporal_cross_attention_dim: Optional[int] = None,
-        temporal_max_seq_length: int = 32,
-    ):
-        super().__init__()
-
-        self.has_cross_attention = True
-        self.num_attention_heads = num_attention_heads
-        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
-
-        # there is always at least one resnet
-        resnets = [
-            ResnetBlock2D(
-                in_channels=in_channels,
-                out_channels=in_channels,
-                temb_channels=temb_channels,
-                eps=resnet_eps,
-                groups=resnet_groups,
-                dropout=dropout,
-                time_embedding_norm=resnet_time_scale_shift,
-                non_linearity=resnet_act_fn,
-                output_scale_factor=output_scale_factor,
-                pre_norm=resnet_pre_norm,
-            )
-        ]
-        attentions = []
-        motion_modules = []
-
-        for _ in range(num_layers):
-            if not dual_cross_attention:
-                attentions.append(
-                    Transformer2DModel(
-                        num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
-                        num_layers=transformer_layers_per_block,
-                        cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
-                        use_linear_projection=use_linear_projection,
-                        upcast_attention=upcast_attention,
-                        attention_type=attention_type,
-                    )
-                )
-            else:
-                attentions.append(
-                    DualTransformer2DModel(
-                        num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
-                        num_layers=1,
-                        cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
-                    )
-                )
-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=in_channels,
-                    out_channels=in_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
-                )
-            )
-            motion_modules.append(
-                TransformerTemporalModel(
-                    num_attention_heads=temporal_num_attention_heads,
-                    attention_head_dim=in_channels // temporal_num_attention_heads,
-                    in_channels=in_channels,
-                    norm_num_groups=resnet_groups,
-                    cross_attention_dim=temporal_cross_attention_dim,
-                    attention_bias=False,
-                    positional_embeddings="sinusoidal",
-                    num_positional_embeddings=temporal_max_seq_length,
-                    activation_fn="geglu",
-                )
-            )
-
-        self.attentions = nn.ModuleList(attentions)
-        self.resnets = nn.ModuleList(resnets)
-        self.motion_modules = nn.ModuleList(motion_modules)
-
-        self.gradient_checkpointing = False
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        temb: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        num_frames: int = 1,
-    ) -> torch.Tensor:
-        if cross_attention_kwargs is not None:
-            if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
-
-        hidden_states = self.resnets[0](hidden_states, temb)
-
-        blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
-        for attn, resnet, motion_module in blocks:
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(motion_module),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
-            else:
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
-                hidden_states = motion_module(
-                    hidden_states,
-                    num_frames=num_frames,
-                )[0]
-                hidden_states = resnet(hidden_states, temb)
-
-        return hidden_states
-
-
 class MidBlockTemporalDecoder(nn.Module):
     def __init__(
         self,
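Every removed `forward()` above uses the same gradient-checkpointing idiom: wrap the submodule in a closure so positional tensors flow through `torch.utils.checkpoint`, and switch to non-reentrant checkpointing when the running torch supports it. The pattern in isolation (a simplified sketch; `checkpointed` is a hypothetical helper, not a diffusers API):

    import torch
    import torch.utils.checkpoint

    from diffusers.utils import is_torch_version

    def checkpointed(module, *inputs):
        # Recompute `module`'s activations during backward instead of storing them.
        def custom_forward(*args):
            return module(*args)

        if is_torch_version(">=", "1.11.0"):
            # Non-reentrant variant, available from torch 1.11 onward.
            return torch.utils.checkpoint.checkpoint(custom_forward, *inputs, use_reentrant=False)
        return torch.utils.checkpoint.checkpoint(custom_forward, *inputs)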