diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220) hide show
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2222 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +1 -12
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +262 -2
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1795 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +319 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +1 -4
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +19 -16
  210. diffusers/utils/loading_utils.py +76 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
@@ -11,6 +11,8 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
+
15
+ from dataclasses import dataclass
14
16
  from typing import Any, Dict, Optional, Tuple, Union
15
17
 
16
18
  import torch
@@ -19,8 +21,10 @@ import torch.nn.functional as F
19
21
  import torch.utils.checkpoint
20
22
 
21
23
  from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
22
- from ...loaders import UNet2DConditionLoadersMixin
23
- from ...utils import logging
24
+ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
25
+ from ...utils import BaseOutput, deprecate, is_torch_version, logging
26
+ from ...utils.torch_utils import apply_freeu
27
+ from ..attention import BasicTransformerBlock
24
28
  from ..attention_processor import (
25
29
  ADDED_KV_ATTENTION_PROCESSORS,
26
30
  CROSS_ATTENTION_PROCESSORS,
@@ -29,35 +33,1114 @@ from ..attention_processor import (
29
33
  AttnAddedKVProcessor,
30
34
  AttnProcessor,
31
35
  AttnProcessor2_0,
36
+ FusedAttnProcessor2_0,
32
37
  IPAdapterAttnProcessor,
33
38
  IPAdapterAttnProcessor2_0,
34
39
  )
35
40
  from ..embeddings import TimestepEmbedding, Timesteps
36
41
  from ..modeling_utils import ModelMixin
37
- from ..transformers.transformer_temporal import TransformerTemporalModel
42
+ from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D
43
+ from ..transformers.dual_transformer_2d import DualTransformer2DModel
44
+ from ..transformers.transformer_2d import Transformer2DModel
38
45
  from .unet_2d_blocks import UNetMidBlock2DCrossAttn
39
46
  from .unet_2d_condition import UNet2DConditionModel
40
- from .unet_3d_blocks import (
41
- CrossAttnDownBlockMotion,
42
- CrossAttnUpBlockMotion,
43
- DownBlockMotion,
44
- UNetMidBlockCrossAttnMotion,
45
- UpBlockMotion,
46
- get_down_block,
47
- get_up_block,
48
- )
49
- from .unet_3d_condition import UNet3DConditionOutput
50
47
 
51
48
 
52
49
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
53
50
 
54
51
 
52
+ @dataclass
53
+ class UNetMotionOutput(BaseOutput):
54
+ """
55
+ The output of [`UNetMotionOutput`].
56
+
57
+ Args:
58
+ sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
59
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
60
+ """
61
+
62
+ sample: torch.Tensor
63
+
64
+
65
+ class AnimateDiffTransformer3D(nn.Module):
66
+ """
67
+ A Transformer model for video-like data.
68
+
69
+ Parameters:
70
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
71
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
72
+ in_channels (`int`, *optional*):
73
+ The number of channels in the input and output (specify if the input is **continuous**).
74
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
75
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
76
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
77
+ attention_bias (`bool`, *optional*):
78
+ Configure if the `TransformerBlock` attention should contain a bias parameter.
79
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
80
+ This is fixed during training since it is used to learn a number of position embeddings.
81
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
82
+ Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
83
+ activation functions.
84
+ norm_elementwise_affine (`bool`, *optional*):
85
+ Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
86
+ double_self_attention (`bool`, *optional*):
87
+ Configure if each `TransformerBlock` should contain two self-attention layers.
88
+ positional_embeddings: (`str`, *optional*):
89
+ The type of positional embeddings to apply to the sequence input before passing use.
90
+ num_positional_embeddings: (`int`, *optional*):
91
+ The maximum length of the sequence over which to apply positional embeddings.
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ num_attention_heads: int = 16,
97
+ attention_head_dim: int = 88,
98
+ in_channels: Optional[int] = None,
99
+ out_channels: Optional[int] = None,
100
+ num_layers: int = 1,
101
+ dropout: float = 0.0,
102
+ norm_num_groups: int = 32,
103
+ cross_attention_dim: Optional[int] = None,
104
+ attention_bias: bool = False,
105
+ sample_size: Optional[int] = None,
106
+ activation_fn: str = "geglu",
107
+ norm_elementwise_affine: bool = True,
108
+ double_self_attention: bool = True,
109
+ positional_embeddings: Optional[str] = None,
110
+ num_positional_embeddings: Optional[int] = None,
111
+ ):
112
+ super().__init__()
113
+ self.num_attention_heads = num_attention_heads
114
+ self.attention_head_dim = attention_head_dim
115
+ inner_dim = num_attention_heads * attention_head_dim
116
+
117
+ self.in_channels = in_channels
118
+
119
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
120
+ self.proj_in = nn.Linear(in_channels, inner_dim)
121
+
122
+ # 3. Define transformers blocks
123
+ self.transformer_blocks = nn.ModuleList(
124
+ [
125
+ BasicTransformerBlock(
126
+ inner_dim,
127
+ num_attention_heads,
128
+ attention_head_dim,
129
+ dropout=dropout,
130
+ cross_attention_dim=cross_attention_dim,
131
+ activation_fn=activation_fn,
132
+ attention_bias=attention_bias,
133
+ double_self_attention=double_self_attention,
134
+ norm_elementwise_affine=norm_elementwise_affine,
135
+ positional_embeddings=positional_embeddings,
136
+ num_positional_embeddings=num_positional_embeddings,
137
+ )
138
+ for _ in range(num_layers)
139
+ ]
140
+ )
141
+
142
+ self.proj_out = nn.Linear(inner_dim, in_channels)
143
+
144
+ def forward(
145
+ self,
146
+ hidden_states: torch.Tensor,
147
+ encoder_hidden_states: Optional[torch.LongTensor] = None,
148
+ timestep: Optional[torch.LongTensor] = None,
149
+ class_labels: Optional[torch.LongTensor] = None,
150
+ num_frames: int = 1,
151
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
152
+ ) -> torch.Tensor:
153
+ """
154
+ The [`AnimateDiffTransformer3D`] forward method.
155
+
156
+ Args:
157
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
158
+ Input hidden_states.
159
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
160
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
161
+ self-attention.
162
+ timestep ( `torch.LongTensor`, *optional*):
163
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
164
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
165
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
166
+ `AdaLayerZeroNorm`.
167
+ num_frames (`int`, *optional*, defaults to 1):
168
+ The number of frames to be processed per batch. This is used to reshape the hidden states.
169
+ cross_attention_kwargs (`dict`, *optional*):
170
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
171
+ `self.processor` in
172
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
173
+
174
+ Returns:
175
+ torch.Tensor:
176
+ The output tensor.
177
+ """
178
+ # 1. Input
179
+ batch_frames, channel, height, width = hidden_states.shape
180
+ batch_size = batch_frames // num_frames
181
+
182
+ residual = hidden_states
183
+
184
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
185
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
186
+
187
+ hidden_states = self.norm(hidden_states)
188
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
189
+
190
+ hidden_states = self.proj_in(hidden_states)
191
+
192
+ # 2. Blocks
193
+ for block in self.transformer_blocks:
194
+ hidden_states = block(
195
+ hidden_states,
196
+ encoder_hidden_states=encoder_hidden_states,
197
+ timestep=timestep,
198
+ cross_attention_kwargs=cross_attention_kwargs,
199
+ class_labels=class_labels,
200
+ )
201
+
202
+ # 3. Output
203
+ hidden_states = self.proj_out(hidden_states)
204
+ hidden_states = (
205
+ hidden_states[None, None, :]
206
+ .reshape(batch_size, height, width, num_frames, channel)
207
+ .permute(0, 3, 4, 1, 2)
208
+ .contiguous()
209
+ )
210
+ hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
211
+
212
+ output = hidden_states + residual
213
+ return output
214
+
215
+
216
+ class DownBlockMotion(nn.Module):
217
+ def __init__(
218
+ self,
219
+ in_channels: int,
220
+ out_channels: int,
221
+ temb_channels: int,
222
+ dropout: float = 0.0,
223
+ num_layers: int = 1,
224
+ resnet_eps: float = 1e-6,
225
+ resnet_time_scale_shift: str = "default",
226
+ resnet_act_fn: str = "swish",
227
+ resnet_groups: int = 32,
228
+ resnet_pre_norm: bool = True,
229
+ output_scale_factor: float = 1.0,
230
+ add_downsample: bool = True,
231
+ downsample_padding: int = 1,
232
+ temporal_num_attention_heads: Union[int, Tuple[int]] = 1,
233
+ temporal_cross_attention_dim: Optional[int] = None,
234
+ temporal_max_seq_length: int = 32,
235
+ temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
236
+ temporal_double_self_attention: bool = True,
237
+ ):
238
+ super().__init__()
239
+ resnets = []
240
+ motion_modules = []
241
+
242
+ # support for variable transformer layers per temporal block
243
+ if isinstance(temporal_transformer_layers_per_block, int):
244
+ temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
245
+ elif len(temporal_transformer_layers_per_block) != num_layers:
246
+ raise ValueError(
247
+ f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}"
248
+ )
249
+
250
+ # support for variable number of attention head per temporal layers
251
+ if isinstance(temporal_num_attention_heads, int):
252
+ temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers
253
+ elif len(temporal_num_attention_heads) != num_layers:
254
+ raise ValueError(
255
+ f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}"
256
+ )
257
+
258
+ for i in range(num_layers):
259
+ in_channels = in_channels if i == 0 else out_channels
260
+ resnets.append(
261
+ ResnetBlock2D(
262
+ in_channels=in_channels,
263
+ out_channels=out_channels,
264
+ temb_channels=temb_channels,
265
+ eps=resnet_eps,
266
+ groups=resnet_groups,
267
+ dropout=dropout,
268
+ time_embedding_norm=resnet_time_scale_shift,
269
+ non_linearity=resnet_act_fn,
270
+ output_scale_factor=output_scale_factor,
271
+ pre_norm=resnet_pre_norm,
272
+ )
273
+ )
274
+ motion_modules.append(
275
+ AnimateDiffTransformer3D(
276
+ num_attention_heads=temporal_num_attention_heads[i],
277
+ in_channels=out_channels,
278
+ num_layers=temporal_transformer_layers_per_block[i],
279
+ norm_num_groups=resnet_groups,
280
+ cross_attention_dim=temporal_cross_attention_dim,
281
+ attention_bias=False,
282
+ activation_fn="geglu",
283
+ positional_embeddings="sinusoidal",
284
+ num_positional_embeddings=temporal_max_seq_length,
285
+ attention_head_dim=out_channels // temporal_num_attention_heads[i],
286
+ double_self_attention=temporal_double_self_attention,
287
+ )
288
+ )
289
+
290
+ self.resnets = nn.ModuleList(resnets)
291
+ self.motion_modules = nn.ModuleList(motion_modules)
292
+
293
+ if add_downsample:
294
+ self.downsamplers = nn.ModuleList(
295
+ [
296
+ Downsample2D(
297
+ out_channels,
298
+ use_conv=True,
299
+ out_channels=out_channels,
300
+ padding=downsample_padding,
301
+ name="op",
302
+ )
303
+ ]
304
+ )
305
+ else:
306
+ self.downsamplers = None
307
+
308
+ self.gradient_checkpointing = False
309
+
310
+ def forward(
311
+ self,
312
+ hidden_states: torch.Tensor,
313
+ temb: Optional[torch.Tensor] = None,
314
+ num_frames: int = 1,
315
+ *args,
316
+ **kwargs,
317
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
318
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
319
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
320
+ deprecate("scale", "1.0.0", deprecation_message)
321
+
322
+ output_states = ()
323
+
324
+ blocks = zip(self.resnets, self.motion_modules)
325
+ for resnet, motion_module in blocks:
326
+ if self.training and self.gradient_checkpointing:
327
+
328
+ def create_custom_forward(module):
329
+ def custom_forward(*inputs):
330
+ return module(*inputs)
331
+
332
+ return custom_forward
333
+
334
+ if is_torch_version(">=", "1.11.0"):
335
+ hidden_states = torch.utils.checkpoint.checkpoint(
336
+ create_custom_forward(resnet),
337
+ hidden_states,
338
+ temb,
339
+ use_reentrant=False,
340
+ )
341
+ else:
342
+ hidden_states = torch.utils.checkpoint.checkpoint(
343
+ create_custom_forward(resnet), hidden_states, temb
344
+ )
345
+
346
+ else:
347
+ hidden_states = resnet(hidden_states, temb)
348
+
349
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)
350
+
351
+ output_states = output_states + (hidden_states,)
352
+
353
+ if self.downsamplers is not None:
354
+ for downsampler in self.downsamplers:
355
+ hidden_states = downsampler(hidden_states)
356
+
357
+ output_states = output_states + (hidden_states,)
358
+
359
+ return hidden_states, output_states
360
+
361
+
362
+ class CrossAttnDownBlockMotion(nn.Module):
363
+ def __init__(
364
+ self,
365
+ in_channels: int,
366
+ out_channels: int,
367
+ temb_channels: int,
368
+ dropout: float = 0.0,
369
+ num_layers: int = 1,
370
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
371
+ resnet_eps: float = 1e-6,
372
+ resnet_time_scale_shift: str = "default",
373
+ resnet_act_fn: str = "swish",
374
+ resnet_groups: int = 32,
375
+ resnet_pre_norm: bool = True,
376
+ num_attention_heads: int = 1,
377
+ cross_attention_dim: int = 1280,
378
+ output_scale_factor: float = 1.0,
379
+ downsample_padding: int = 1,
380
+ add_downsample: bool = True,
381
+ dual_cross_attention: bool = False,
382
+ use_linear_projection: bool = False,
383
+ only_cross_attention: bool = False,
384
+ upcast_attention: bool = False,
385
+ attention_type: str = "default",
386
+ temporal_cross_attention_dim: Optional[int] = None,
387
+ temporal_num_attention_heads: int = 8,
388
+ temporal_max_seq_length: int = 32,
389
+ temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
390
+ temporal_double_self_attention: bool = True,
391
+ ):
392
+ super().__init__()
393
+ resnets = []
394
+ attentions = []
395
+ motion_modules = []
396
+
397
+ self.has_cross_attention = True
398
+ self.num_attention_heads = num_attention_heads
399
+
400
+ # support for variable transformer layers per block
401
+ if isinstance(transformer_layers_per_block, int):
402
+ transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
403
+ elif len(transformer_layers_per_block) != num_layers:
404
+ raise ValueError(
405
+ f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
406
+ )
407
+
408
+ # support for variable transformer layers per temporal block
409
+ if isinstance(temporal_transformer_layers_per_block, int):
410
+ temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
411
+ elif len(temporal_transformer_layers_per_block) != num_layers:
412
+ raise ValueError(
413
+ f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
414
+ )
415
+
416
+ for i in range(num_layers):
417
+ in_channels = in_channels if i == 0 else out_channels
418
+ resnets.append(
419
+ ResnetBlock2D(
420
+ in_channels=in_channels,
421
+ out_channels=out_channels,
422
+ temb_channels=temb_channels,
423
+ eps=resnet_eps,
424
+ groups=resnet_groups,
425
+ dropout=dropout,
426
+ time_embedding_norm=resnet_time_scale_shift,
427
+ non_linearity=resnet_act_fn,
428
+ output_scale_factor=output_scale_factor,
429
+ pre_norm=resnet_pre_norm,
430
+ )
431
+ )
432
+
433
+ if not dual_cross_attention:
434
+ attentions.append(
435
+ Transformer2DModel(
436
+ num_attention_heads,
437
+ out_channels // num_attention_heads,
438
+ in_channels=out_channels,
439
+ num_layers=transformer_layers_per_block[i],
440
+ cross_attention_dim=cross_attention_dim,
441
+ norm_num_groups=resnet_groups,
442
+ use_linear_projection=use_linear_projection,
443
+ only_cross_attention=only_cross_attention,
444
+ upcast_attention=upcast_attention,
445
+ attention_type=attention_type,
446
+ )
447
+ )
448
+ else:
449
+ attentions.append(
450
+ DualTransformer2DModel(
451
+ num_attention_heads,
452
+ out_channels // num_attention_heads,
453
+ in_channels=out_channels,
454
+ num_layers=1,
455
+ cross_attention_dim=cross_attention_dim,
456
+ norm_num_groups=resnet_groups,
457
+ )
458
+ )
459
+
460
+ motion_modules.append(
461
+ AnimateDiffTransformer3D(
462
+ num_attention_heads=temporal_num_attention_heads,
463
+ in_channels=out_channels,
464
+ num_layers=temporal_transformer_layers_per_block[i],
465
+ norm_num_groups=resnet_groups,
466
+ cross_attention_dim=temporal_cross_attention_dim,
467
+ attention_bias=False,
468
+ activation_fn="geglu",
469
+ positional_embeddings="sinusoidal",
470
+ num_positional_embeddings=temporal_max_seq_length,
471
+ attention_head_dim=out_channels // temporal_num_attention_heads,
472
+ double_self_attention=temporal_double_self_attention,
473
+ )
474
+ )
475
+
476
+ self.attentions = nn.ModuleList(attentions)
477
+ self.resnets = nn.ModuleList(resnets)
478
+ self.motion_modules = nn.ModuleList(motion_modules)
479
+
480
+ if add_downsample:
481
+ self.downsamplers = nn.ModuleList(
482
+ [
483
+ Downsample2D(
484
+ out_channels,
485
+ use_conv=True,
486
+ out_channels=out_channels,
487
+ padding=downsample_padding,
488
+ name="op",
489
+ )
490
+ ]
491
+ )
492
+ else:
493
+ self.downsamplers = None
494
+
495
+ self.gradient_checkpointing = False
496
+
497
+ def forward(
498
+ self,
499
+ hidden_states: torch.Tensor,
500
+ temb: Optional[torch.Tensor] = None,
501
+ encoder_hidden_states: Optional[torch.Tensor] = None,
502
+ attention_mask: Optional[torch.Tensor] = None,
503
+ num_frames: int = 1,
504
+ encoder_attention_mask: Optional[torch.Tensor] = None,
505
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
506
+ additional_residuals: Optional[torch.Tensor] = None,
507
+ ):
508
+ if cross_attention_kwargs is not None:
509
+ if cross_attention_kwargs.get("scale", None) is not None:
510
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
511
+
512
+ output_states = ()
513
+
514
+ blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
515
+ for i, (resnet, attn, motion_module) in enumerate(blocks):
516
+ if self.training and self.gradient_checkpointing:
517
+
518
+ def create_custom_forward(module, return_dict=None):
519
+ def custom_forward(*inputs):
520
+ if return_dict is not None:
521
+ return module(*inputs, return_dict=return_dict)
522
+ else:
523
+ return module(*inputs)
524
+
525
+ return custom_forward
526
+
527
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
528
+ hidden_states = torch.utils.checkpoint.checkpoint(
529
+ create_custom_forward(resnet),
530
+ hidden_states,
531
+ temb,
532
+ **ckpt_kwargs,
533
+ )
534
+ hidden_states = attn(
535
+ hidden_states,
536
+ encoder_hidden_states=encoder_hidden_states,
537
+ cross_attention_kwargs=cross_attention_kwargs,
538
+ attention_mask=attention_mask,
539
+ encoder_attention_mask=encoder_attention_mask,
540
+ return_dict=False,
541
+ )[0]
542
+ else:
543
+ hidden_states = resnet(hidden_states, temb)
544
+
545
+ hidden_states = attn(
546
+ hidden_states,
547
+ encoder_hidden_states=encoder_hidden_states,
548
+ cross_attention_kwargs=cross_attention_kwargs,
549
+ attention_mask=attention_mask,
550
+ encoder_attention_mask=encoder_attention_mask,
551
+ return_dict=False,
552
+ )[0]
553
+ hidden_states = motion_module(
554
+ hidden_states,
555
+ num_frames=num_frames,
556
+ )
557
+
558
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
559
+ if i == len(blocks) - 1 and additional_residuals is not None:
560
+ hidden_states = hidden_states + additional_residuals
561
+
562
+ output_states = output_states + (hidden_states,)
563
+
564
+ if self.downsamplers is not None:
565
+ for downsampler in self.downsamplers:
566
+ hidden_states = downsampler(hidden_states)
567
+
568
+ output_states = output_states + (hidden_states,)
569
+
570
+ return hidden_states, output_states
571
+
572
+
573
+ class CrossAttnUpBlockMotion(nn.Module):
574
+ def __init__(
575
+ self,
576
+ in_channels: int,
577
+ out_channels: int,
578
+ prev_output_channel: int,
579
+ temb_channels: int,
580
+ resolution_idx: Optional[int] = None,
581
+ dropout: float = 0.0,
582
+ num_layers: int = 1,
583
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
584
+ resnet_eps: float = 1e-6,
585
+ resnet_time_scale_shift: str = "default",
586
+ resnet_act_fn: str = "swish",
587
+ resnet_groups: int = 32,
588
+ resnet_pre_norm: bool = True,
589
+ num_attention_heads: int = 1,
590
+ cross_attention_dim: int = 1280,
591
+ output_scale_factor: float = 1.0,
592
+ add_upsample: bool = True,
593
+ dual_cross_attention: bool = False,
594
+ use_linear_projection: bool = False,
595
+ only_cross_attention: bool = False,
596
+ upcast_attention: bool = False,
597
+ attention_type: str = "default",
598
+ temporal_cross_attention_dim: Optional[int] = None,
599
+ temporal_num_attention_heads: int = 8,
600
+ temporal_max_seq_length: int = 32,
601
+ temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
602
+ ):
603
+ super().__init__()
604
+ resnets = []
605
+ attentions = []
606
+ motion_modules = []
607
+
608
+ self.has_cross_attention = True
609
+ self.num_attention_heads = num_attention_heads
610
+
611
+ # support for variable transformer layers per block
612
+ if isinstance(transformer_layers_per_block, int):
613
+ transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
614
+ elif len(transformer_layers_per_block) != num_layers:
615
+ raise ValueError(
616
+ f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}"
617
+ )
618
+
619
+ # support for variable transformer layers per temporal block
620
+ if isinstance(temporal_transformer_layers_per_block, int):
621
+ temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
622
+ elif len(temporal_transformer_layers_per_block) != num_layers:
623
+ raise ValueError(
624
+ f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}"
625
+ )
626
+
627
+ for i in range(num_layers):
628
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
629
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
630
+
631
+ resnets.append(
632
+ ResnetBlock2D(
633
+ in_channels=resnet_in_channels + res_skip_channels,
634
+ out_channels=out_channels,
635
+ temb_channels=temb_channels,
636
+ eps=resnet_eps,
637
+ groups=resnet_groups,
638
+ dropout=dropout,
639
+ time_embedding_norm=resnet_time_scale_shift,
640
+ non_linearity=resnet_act_fn,
641
+ output_scale_factor=output_scale_factor,
642
+ pre_norm=resnet_pre_norm,
643
+ )
644
+ )
645
+
646
+ if not dual_cross_attention:
647
+ attentions.append(
648
+ Transformer2DModel(
649
+ num_attention_heads,
650
+ out_channels // num_attention_heads,
651
+ in_channels=out_channels,
652
+ num_layers=transformer_layers_per_block[i],
653
+ cross_attention_dim=cross_attention_dim,
654
+ norm_num_groups=resnet_groups,
655
+ use_linear_projection=use_linear_projection,
656
+ only_cross_attention=only_cross_attention,
657
+ upcast_attention=upcast_attention,
658
+ attention_type=attention_type,
659
+ )
660
+ )
661
+ else:
662
+ attentions.append(
663
+ DualTransformer2DModel(
664
+ num_attention_heads,
665
+ out_channels // num_attention_heads,
666
+ in_channels=out_channels,
667
+ num_layers=1,
668
+ cross_attention_dim=cross_attention_dim,
669
+ norm_num_groups=resnet_groups,
670
+ )
671
+ )
672
+ motion_modules.append(
673
+ AnimateDiffTransformer3D(
674
+ num_attention_heads=temporal_num_attention_heads,
675
+ in_channels=out_channels,
676
+ num_layers=temporal_transformer_layers_per_block[i],
677
+ norm_num_groups=resnet_groups,
678
+ cross_attention_dim=temporal_cross_attention_dim,
679
+ attention_bias=False,
680
+ activation_fn="geglu",
681
+ positional_embeddings="sinusoidal",
682
+ num_positional_embeddings=temporal_max_seq_length,
683
+ attention_head_dim=out_channels // temporal_num_attention_heads,
684
+ )
685
+ )
686
+
687
+ self.attentions = nn.ModuleList(attentions)
688
+ self.resnets = nn.ModuleList(resnets)
689
+ self.motion_modules = nn.ModuleList(motion_modules)
690
+
691
+ if add_upsample:
692
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
693
+ else:
694
+ self.upsamplers = None
695
+
696
+ self.gradient_checkpointing = False
697
+ self.resolution_idx = resolution_idx
698
+
699
+ def forward(
700
+ self,
701
+ hidden_states: torch.Tensor,
702
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
703
+ temb: Optional[torch.Tensor] = None,
704
+ encoder_hidden_states: Optional[torch.Tensor] = None,
705
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
706
+ upsample_size: Optional[int] = None,
707
+ attention_mask: Optional[torch.Tensor] = None,
708
+ encoder_attention_mask: Optional[torch.Tensor] = None,
709
+ num_frames: int = 1,
710
+ ) -> torch.Tensor:
711
+ if cross_attention_kwargs is not None:
712
+ if cross_attention_kwargs.get("scale", None) is not None:
713
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
714
+
715
+ is_freeu_enabled = (
716
+ getattr(self, "s1", None)
717
+ and getattr(self, "s2", None)
718
+ and getattr(self, "b1", None)
719
+ and getattr(self, "b2", None)
720
+ )
721
+
722
+ blocks = zip(self.resnets, self.attentions, self.motion_modules)
723
+ for resnet, attn, motion_module in blocks:
724
+ # pop res hidden states
725
+ res_hidden_states = res_hidden_states_tuple[-1]
726
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
727
+
728
+ # FreeU: Only operate on the first two stages
729
+ if is_freeu_enabled:
730
+ hidden_states, res_hidden_states = apply_freeu(
731
+ self.resolution_idx,
732
+ hidden_states,
733
+ res_hidden_states,
734
+ s1=self.s1,
735
+ s2=self.s2,
736
+ b1=self.b1,
737
+ b2=self.b2,
738
+ )
739
+
740
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
741
+
742
+ if self.training and self.gradient_checkpointing:
743
+
744
+ def create_custom_forward(module, return_dict=None):
745
+ def custom_forward(*inputs):
746
+ if return_dict is not None:
747
+ return module(*inputs, return_dict=return_dict)
748
+ else:
749
+ return module(*inputs)
750
+
751
+ return custom_forward
752
+
753
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
754
+ hidden_states = torch.utils.checkpoint.checkpoint(
755
+ create_custom_forward(resnet),
756
+ hidden_states,
757
+ temb,
758
+ **ckpt_kwargs,
759
+ )
760
+ hidden_states = attn(
761
+ hidden_states,
762
+ encoder_hidden_states=encoder_hidden_states,
763
+ cross_attention_kwargs=cross_attention_kwargs,
764
+ attention_mask=attention_mask,
765
+ encoder_attention_mask=encoder_attention_mask,
766
+ return_dict=False,
767
+ )[0]
768
+ else:
769
+ hidden_states = resnet(hidden_states, temb)
770
+
771
+ hidden_states = attn(
772
+ hidden_states,
773
+ encoder_hidden_states=encoder_hidden_states,
774
+ cross_attention_kwargs=cross_attention_kwargs,
775
+ attention_mask=attention_mask,
776
+ encoder_attention_mask=encoder_attention_mask,
777
+ return_dict=False,
778
+ )[0]
779
+ hidden_states = motion_module(
780
+ hidden_states,
781
+ num_frames=num_frames,
782
+ )
783
+
784
+ if self.upsamplers is not None:
785
+ for upsampler in self.upsamplers:
786
+ hidden_states = upsampler(hidden_states, upsample_size)
787
+
788
+ return hidden_states
789
+
790
+
791
+ class UpBlockMotion(nn.Module):
792
+ def __init__(
793
+ self,
794
+ in_channels: int,
795
+ prev_output_channel: int,
796
+ out_channels: int,
797
+ temb_channels: int,
798
+ resolution_idx: Optional[int] = None,
799
+ dropout: float = 0.0,
800
+ num_layers: int = 1,
801
+ resnet_eps: float = 1e-6,
802
+ resnet_time_scale_shift: str = "default",
803
+ resnet_act_fn: str = "swish",
804
+ resnet_groups: int = 32,
805
+ resnet_pre_norm: bool = True,
806
+ output_scale_factor: float = 1.0,
807
+ add_upsample: bool = True,
808
+ temporal_cross_attention_dim: Optional[int] = None,
809
+ temporal_num_attention_heads: int = 8,
810
+ temporal_max_seq_length: int = 32,
811
+ temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
812
+ ):
813
+ super().__init__()
814
+ resnets = []
815
+ motion_modules = []
816
+
817
+ # support for variable transformer layers per temporal block
818
+ if isinstance(temporal_transformer_layers_per_block, int):
819
+ temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
820
+ elif len(temporal_transformer_layers_per_block) != num_layers:
821
+ raise ValueError(
822
+ f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
823
+ )
824
+
825
+ for i in range(num_layers):
826
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
827
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
828
+
829
+ resnets.append(
830
+ ResnetBlock2D(
831
+ in_channels=resnet_in_channels + res_skip_channels,
832
+ out_channels=out_channels,
833
+ temb_channels=temb_channels,
834
+ eps=resnet_eps,
835
+ groups=resnet_groups,
836
+ dropout=dropout,
837
+ time_embedding_norm=resnet_time_scale_shift,
838
+ non_linearity=resnet_act_fn,
839
+ output_scale_factor=output_scale_factor,
840
+ pre_norm=resnet_pre_norm,
841
+ )
842
+ )
843
+
844
+ motion_modules.append(
845
+ AnimateDiffTransformer3D(
846
+ num_attention_heads=temporal_num_attention_heads,
847
+ in_channels=out_channels,
848
+ num_layers=temporal_transformer_layers_per_block[i],
849
+ norm_num_groups=resnet_groups,
850
+ cross_attention_dim=temporal_cross_attention_dim,
851
+ attention_bias=False,
852
+ activation_fn="geglu",
853
+ positional_embeddings="sinusoidal",
854
+ num_positional_embeddings=temporal_max_seq_length,
855
+ attention_head_dim=out_channels // temporal_num_attention_heads,
856
+ )
857
+ )
858
+
859
+ self.resnets = nn.ModuleList(resnets)
860
+ self.motion_modules = nn.ModuleList(motion_modules)
861
+
862
+ if add_upsample:
863
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
864
+ else:
865
+ self.upsamplers = None
866
+
867
+ self.gradient_checkpointing = False
868
+ self.resolution_idx = resolution_idx
869
+
870
+ def forward(
871
+ self,
872
+ hidden_states: torch.Tensor,
873
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
874
+ temb: Optional[torch.Tensor] = None,
875
+ upsample_size=None,
876
+ num_frames: int = 1,
877
+ *args,
878
+ **kwargs,
879
+ ) -> torch.Tensor:
880
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
881
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
882
+ deprecate("scale", "1.0.0", deprecation_message)
883
+
884
+ is_freeu_enabled = (
885
+ getattr(self, "s1", None)
886
+ and getattr(self, "s2", None)
887
+ and getattr(self, "b1", None)
888
+ and getattr(self, "b2", None)
889
+ )
890
+
891
+ blocks = zip(self.resnets, self.motion_modules)
892
+
893
+ for resnet, motion_module in blocks:
894
+ # pop res hidden states
895
+ res_hidden_states = res_hidden_states_tuple[-1]
896
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
897
+
898
+ # FreeU: Only operate on the first two stages
899
+ if is_freeu_enabled:
900
+ hidden_states, res_hidden_states = apply_freeu(
901
+ self.resolution_idx,
902
+ hidden_states,
903
+ res_hidden_states,
904
+ s1=self.s1,
905
+ s2=self.s2,
906
+ b1=self.b1,
907
+ b2=self.b2,
908
+ )
909
+
910
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
911
+
912
+ if self.training and self.gradient_checkpointing:
913
+
914
+ def create_custom_forward(module):
915
+ def custom_forward(*inputs):
916
+ return module(*inputs)
917
+
918
+ return custom_forward
919
+
920
+ if is_torch_version(">=", "1.11.0"):
921
+ hidden_states = torch.utils.checkpoint.checkpoint(
922
+ create_custom_forward(resnet),
923
+ hidden_states,
924
+ temb,
925
+ use_reentrant=False,
926
+ )
927
+ else:
928
+ hidden_states = torch.utils.checkpoint.checkpoint(
929
+ create_custom_forward(resnet), hidden_states, temb
930
+ )
931
+ else:
932
+ hidden_states = resnet(hidden_states, temb)
933
+
934
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)
935
+
936
+ if self.upsamplers is not None:
937
+ for upsampler in self.upsamplers:
938
+ hidden_states = upsampler(hidden_states, upsample_size)
939
+
940
+ return hidden_states
941
+
942
+
943
+ class UNetMidBlockCrossAttnMotion(nn.Module):
944
+ def __init__(
945
+ self,
946
+ in_channels: int,
947
+ temb_channels: int,
948
+ dropout: float = 0.0,
949
+ num_layers: int = 1,
950
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
951
+ resnet_eps: float = 1e-6,
952
+ resnet_time_scale_shift: str = "default",
953
+ resnet_act_fn: str = "swish",
954
+ resnet_groups: int = 32,
955
+ resnet_pre_norm: bool = True,
956
+ num_attention_heads: int = 1,
957
+ output_scale_factor: float = 1.0,
958
+ cross_attention_dim: int = 1280,
959
+ dual_cross_attention: bool = False,
960
+ use_linear_projection: bool = False,
961
+ upcast_attention: bool = False,
962
+ attention_type: str = "default",
963
+ temporal_num_attention_heads: int = 1,
964
+ temporal_cross_attention_dim: Optional[int] = None,
965
+ temporal_max_seq_length: int = 32,
966
+ temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
967
+ ):
968
+ super().__init__()
969
+
970
+ self.has_cross_attention = True
971
+ self.num_attention_heads = num_attention_heads
972
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
973
+
974
+ # support for variable transformer layers per block
975
+ if isinstance(transformer_layers_per_block, int):
976
+ transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
977
+ elif len(transformer_layers_per_block) != num_layers:
978
+ raise ValueError(
979
+ f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
980
+ )
981
+
982
+ # support for variable transformer layers per temporal block
983
+ if isinstance(temporal_transformer_layers_per_block, int):
984
+ temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
985
+ elif len(temporal_transformer_layers_per_block) != num_layers:
986
+ raise ValueError(
987
+ f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
988
+ )
989
+
990
+ # there is always at least one resnet
991
+ resnets = [
992
+ ResnetBlock2D(
993
+ in_channels=in_channels,
994
+ out_channels=in_channels,
995
+ temb_channels=temb_channels,
996
+ eps=resnet_eps,
997
+ groups=resnet_groups,
998
+ dropout=dropout,
999
+ time_embedding_norm=resnet_time_scale_shift,
1000
+ non_linearity=resnet_act_fn,
1001
+ output_scale_factor=output_scale_factor,
1002
+ pre_norm=resnet_pre_norm,
1003
+ )
1004
+ ]
1005
+ attentions = []
1006
+ motion_modules = []
1007
+
1008
+ for i in range(num_layers):
1009
+ if not dual_cross_attention:
1010
+ attentions.append(
1011
+ Transformer2DModel(
1012
+ num_attention_heads,
1013
+ in_channels // num_attention_heads,
1014
+ in_channels=in_channels,
1015
+ num_layers=transformer_layers_per_block[i],
1016
+ cross_attention_dim=cross_attention_dim,
1017
+ norm_num_groups=resnet_groups,
1018
+ use_linear_projection=use_linear_projection,
1019
+ upcast_attention=upcast_attention,
1020
+ attention_type=attention_type,
1021
+ )
1022
+ )
1023
+ else:
1024
+ attentions.append(
1025
+ DualTransformer2DModel(
1026
+ num_attention_heads,
1027
+ in_channels // num_attention_heads,
1028
+ in_channels=in_channels,
1029
+ num_layers=1,
1030
+ cross_attention_dim=cross_attention_dim,
1031
+ norm_num_groups=resnet_groups,
1032
+ )
1033
+ )
1034
+ resnets.append(
1035
+ ResnetBlock2D(
1036
+ in_channels=in_channels,
1037
+ out_channels=in_channels,
1038
+ temb_channels=temb_channels,
1039
+ eps=resnet_eps,
1040
+ groups=resnet_groups,
1041
+ dropout=dropout,
1042
+ time_embedding_norm=resnet_time_scale_shift,
1043
+ non_linearity=resnet_act_fn,
1044
+ output_scale_factor=output_scale_factor,
1045
+ pre_norm=resnet_pre_norm,
1046
+ )
1047
+ )
1048
+ motion_modules.append(
1049
+ AnimateDiffTransformer3D(
1050
+ num_attention_heads=temporal_num_attention_heads,
1051
+ attention_head_dim=in_channels // temporal_num_attention_heads,
1052
+ in_channels=in_channels,
1053
+ num_layers=temporal_transformer_layers_per_block[i],
1054
+ norm_num_groups=resnet_groups,
1055
+ cross_attention_dim=temporal_cross_attention_dim,
1056
+ attention_bias=False,
1057
+ positional_embeddings="sinusoidal",
1058
+ num_positional_embeddings=temporal_max_seq_length,
1059
+ activation_fn="geglu",
1060
+ )
1061
+ )
1062
+
1063
+ self.attentions = nn.ModuleList(attentions)
1064
+ self.resnets = nn.ModuleList(resnets)
1065
+ self.motion_modules = nn.ModuleList(motion_modules)
1066
+
1067
+ self.gradient_checkpointing = False
1068
+
1069
+ def forward(
1070
+ self,
1071
+ hidden_states: torch.Tensor,
1072
+ temb: Optional[torch.Tensor] = None,
1073
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1074
+ attention_mask: Optional[torch.Tensor] = None,
1075
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1076
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1077
+ num_frames: int = 1,
1078
+ ) -> torch.Tensor:
1079
+ if cross_attention_kwargs is not None:
1080
+ if cross_attention_kwargs.get("scale", None) is not None:
1081
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1082
+
1083
+ hidden_states = self.resnets[0](hidden_states, temb)
1084
+
1085
+ blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
1086
+ for attn, resnet, motion_module in blocks:
1087
+ if self.training and self.gradient_checkpointing:
1088
+
1089
+ def create_custom_forward(module, return_dict=None):
1090
+ def custom_forward(*inputs):
1091
+ if return_dict is not None:
1092
+ return module(*inputs, return_dict=return_dict)
1093
+ else:
1094
+ return module(*inputs)
1095
+
1096
+ return custom_forward
1097
+
1098
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1099
+ hidden_states = attn(
1100
+ hidden_states,
1101
+ encoder_hidden_states=encoder_hidden_states,
1102
+ cross_attention_kwargs=cross_attention_kwargs,
1103
+ attention_mask=attention_mask,
1104
+ encoder_attention_mask=encoder_attention_mask,
1105
+ return_dict=False,
1106
+ )[0]
1107
+ hidden_states = torch.utils.checkpoint.checkpoint(
1108
+ create_custom_forward(motion_module),
1109
+ hidden_states,
1110
+ temb,
1111
+ **ckpt_kwargs,
1112
+ )
1113
+ hidden_states = torch.utils.checkpoint.checkpoint(
1114
+ create_custom_forward(resnet),
1115
+ hidden_states,
1116
+ temb,
1117
+ **ckpt_kwargs,
1118
+ )
1119
+ else:
1120
+ hidden_states = attn(
1121
+ hidden_states,
1122
+ encoder_hidden_states=encoder_hidden_states,
1123
+ cross_attention_kwargs=cross_attention_kwargs,
1124
+ attention_mask=attention_mask,
1125
+ encoder_attention_mask=encoder_attention_mask,
1126
+ return_dict=False,
1127
+ )[0]
1128
+ hidden_states = motion_module(
1129
+ hidden_states,
1130
+ num_frames=num_frames,
1131
+ )
1132
+ hidden_states = resnet(hidden_states, temb)
1133
+
1134
+ return hidden_states
1135
+
1136
+
  class MotionModules(nn.Module):
  def __init__(
  self,
  in_channels: int,
  layers_per_block: int = 2,
- num_attention_heads: int = 8,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Union[int, Tuple[int]] = 8,
  attention_bias: bool = False,
  cross_attention_dim: Optional[int] = None,
  activation_fn: str = "geglu",
@@ -67,10 +1150,19 @@ class MotionModules(nn.Module):
  super().__init__()
  self.motion_modules = nn.ModuleList([])

+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = (transformer_layers_per_block,) * layers_per_block
+ elif len(transformer_layers_per_block) != layers_per_block:
+ raise ValueError(
+ f"The number of transformer layers per block must match the number of layers per block, "
+ f"got {layers_per_block} and {len(transformer_layers_per_block)}"
+ )
+
  for i in range(layers_per_block):
  self.motion_modules.append(
- TransformerTemporalModel(
+ AnimateDiffTransformer3D(
  in_channels=in_channels,
+ num_layers=transformer_layers_per_block[i],
  norm_num_groups=norm_num_groups,
  cross_attention_dim=cross_attention_dim,
  activation_fn=activation_fn,
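`MotionModules` now accepts either a single int or a per-layer tuple for `transformer_layers_per_block`, broadcasting the int and rejecting mismatched tuple lengths; `MotionAdapter` and `UNetMotionModel` below apply the same normalization to their new per-block arguments. A stand-alone sketch of the rule (helper name and values are illustrative, not part of the library):

```python
from typing import Tuple, Union

def expand_per_block(value: Union[int, Tuple[int, ...]], num_blocks: int) -> Tuple[int, ...]:
    # An int is broadcast to one entry per block; a tuple must already match the block count.
    if isinstance(value, int):
        return (value,) * num_blocks
    if len(value) != num_blocks:
        raise ValueError(f"Expected {num_blocks} entries, got {len(value)}")
    return tuple(value)

assert expand_per_block(1, 4) == (1, 1, 1, 1)
assert expand_per_block((1, 2, 4, 4), 4) == (1, 2, 4, 4)
```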
@@ -83,14 +1175,16 @@ class MotionModules(nn.Module):
  )


- class MotionAdapter(ModelMixin, ConfigMixin):
+ class MotionAdapter(ModelMixin, ConfigMixin, FromOriginalModelMixin):
  @register_to_config
  def __init__(
  self,
  block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
- motion_layers_per_block: int = 2,
+ motion_layers_per_block: Union[int, Tuple[int]] = 2,
+ motion_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1,
  motion_mid_block_layers_per_block: int = 1,
- motion_num_attention_heads: int = 8,
+ motion_transformer_layers_per_mid_block: Union[int, Tuple[int]] = 1,
+ motion_num_attention_heads: Union[int, Tuple[int]] = 8,
  motion_norm_num_groups: int = 32,
  motion_max_seq_length: int = 32,
  use_motion_mid_block: bool = True,
@@ -101,11 +1195,15 @@ class MotionAdapter(ModelMixin, ConfigMixin):
  Args:
  block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
  The tuple of output channels for each UNet block.
- motion_layers_per_block (`int`, *optional*, defaults to 2):
+ motion_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 2):
  The number of motion layers per UNet block.
+ motion_transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple[int]]`, *optional*, defaults to 1):
+ The number of transformer layers to use in each motion layer in each block.
  motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1):
  The number of motion layers in the middle UNet block.
- motion_num_attention_heads (`int`, *optional*, defaults to 8):
+ motion_transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer layers to use in each motion layer in the middle block.
+ motion_num_attention_heads (`int` or `Tuple[int]`, *optional*, defaults to 8):
  The number of heads to use in each attention layer of the motion module.
  motion_norm_num_groups (`int`, *optional*, defaults to 32):
  The number of groups to use in each group normalization layer of the motion module.
@@ -119,6 +1217,35 @@ class MotionAdapter(ModelMixin, ConfigMixin):
  down_blocks = []
  up_blocks = []

+ if isinstance(motion_layers_per_block, int):
+ motion_layers_per_block = (motion_layers_per_block,) * len(block_out_channels)
+ elif len(motion_layers_per_block) != len(block_out_channels):
+ raise ValueError(
+ f"The number of motion layers per block must match the number of blocks, "
+ f"got {len(block_out_channels)} and {len(motion_layers_per_block)}"
+ )
+
+ if isinstance(motion_transformer_layers_per_block, int):
+ motion_transformer_layers_per_block = (motion_transformer_layers_per_block,) * len(block_out_channels)
+
+ if isinstance(motion_transformer_layers_per_mid_block, int):
+ motion_transformer_layers_per_mid_block = (
+ motion_transformer_layers_per_mid_block,
+ ) * motion_mid_block_layers_per_block
+ elif len(motion_transformer_layers_per_mid_block) != motion_mid_block_layers_per_block:
+ raise ValueError(
+ f"The number of layers per mid block ({motion_mid_block_layers_per_block}) "
+ f"must match the length of motion_transformer_layers_per_mid_block ({len(motion_transformer_layers_per_mid_block)})"
+ )
+
+ if isinstance(motion_num_attention_heads, int):
+ motion_num_attention_heads = (motion_num_attention_heads,) * len(block_out_channels)
+ elif len(motion_num_attention_heads) != len(block_out_channels):
+ raise ValueError(
+ f"The length of the attention head number tuple in the motion module must match the "
+ f"number of block, got {len(motion_num_attention_heads)} and {len(block_out_channels)}"
+ )
+
  if conv_in_channels:
  # input
  self.conv_in = nn.Conv2d(conv_in_channels, block_out_channels[0], kernel_size=3, padding=1)
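With the config entries documented above now accepting tuples, the motion depth can vary per UNet block. A hedged construction sketch (the tuple values are illustrative, not taken from a released checkpoint):

```python
from diffusers import MotionAdapter

adapter = MotionAdapter(
    block_out_channels=(320, 640, 1280, 1280),
    motion_layers_per_block=(2, 2, 2, 2),              # int or one entry per block
    motion_transformer_layers_per_block=(1, 1, 1, 1),  # transformer depth inside each motion layer
    motion_transformer_layers_per_mid_block=1,
    motion_num_attention_heads=(8, 8, 8, 8),
    motion_max_seq_length=32,
    use_motion_mid_block=True,
)
```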
@@ -134,9 +1261,10 @@ class MotionAdapter(ModelMixin, ConfigMixin):
  cross_attention_dim=None,
  activation_fn="geglu",
  attention_bias=False,
- num_attention_heads=motion_num_attention_heads,
+ num_attention_heads=motion_num_attention_heads[i],
  max_seq_length=motion_max_seq_length,
- layers_per_block=motion_layers_per_block,
+ layers_per_block=motion_layers_per_block[i],
+ transformer_layers_per_block=motion_transformer_layers_per_block[i],
  )
  )

@@ -147,15 +1275,20 @@ class MotionAdapter(ModelMixin, ConfigMixin):
  cross_attention_dim=None,
  activation_fn="geglu",
  attention_bias=False,
- num_attention_heads=motion_num_attention_heads,
- layers_per_block=motion_mid_block_layers_per_block,
+ num_attention_heads=motion_num_attention_heads[-1],
  max_seq_length=motion_max_seq_length,
+ layers_per_block=motion_mid_block_layers_per_block,
+ transformer_layers_per_block=motion_transformer_layers_per_mid_block,
  )
  else:
  self.mid_block = None

  reversed_block_out_channels = list(reversed(block_out_channels))
  output_channel = reversed_block_out_channels[0]
+
+ reversed_motion_layers_per_block = list(reversed(motion_layers_per_block))
+ reversed_motion_transformer_layers_per_block = list(reversed(motion_transformer_layers_per_block))
+ reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads))
  for i, channel in enumerate(reversed_block_out_channels):
  output_channel = reversed_block_out_channels[i]
  up_blocks.append(
@@ -165,9 +1298,10 @@ class MotionAdapter(ModelMixin, ConfigMixin):
  cross_attention_dim=None,
  activation_fn="geglu",
  attention_bias=False,
- num_attention_heads=motion_num_attention_heads,
+ num_attention_heads=reversed_motion_num_attention_heads[i],
  max_seq_length=motion_max_seq_length,
- layers_per_block=motion_layers_per_block + 1,
+ layers_per_block=reversed_motion_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_motion_transformer_layers_per_block[i],
  )
  )

@@ -178,7 +1312,7 @@ class MotionAdapter(ModelMixin, ConfigMixin):
  pass


- class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
  r"""
  A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a
  sample shaped output.
@@ -208,7 +1342,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  "CrossAttnUpBlockMotion",
  ),
  block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
- layers_per_block: int = 2,
+ layers_per_block: Union[int, Tuple[int]] = 2,
  downsample_padding: int = 1,
  mid_block_scale_factor: float = 1,
  act_fn: str = "silu",
@@ -216,12 +1350,18 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  norm_eps: float = 1e-5,
  cross_attention_dim: int = 1280,
  transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
- reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+ reverse_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
+ temporal_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ reverse_temporal_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
+ transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
+ temporal_transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = 1,
  use_linear_projection: bool = False,
  num_attention_heads: Union[int, Tuple[int, ...]] = 8,
  motion_max_seq_length: int = 32,
- motion_num_attention_heads: int = 8,
- use_motion_mid_block: int = True,
+ motion_num_attention_heads: Union[int, Tuple[int, ...]] = 8,
+ reverse_motion_num_attention_heads: Optional[Union[int, Tuple[int, ...], Tuple[Tuple[int, ...], ...]]] = None,
+ use_motion_mid_block: bool = True,
+ mid_block_layers: int = 1,
  encoder_hid_dim: Optional[int] = None,
  encoder_hid_dim_type: Optional[str] = None,
  addition_embed_type: Optional[str] = None,
@@ -264,6 +1404,16 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  if isinstance(layer_number_per_block, list):
  raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

+ if (
+ isinstance(temporal_transformer_layers_per_block, list)
+ and reverse_temporal_transformer_layers_per_block is None
+ ):
+ for layer_number_per_block in temporal_transformer_layers_per_block:
+ if isinstance(layer_number_per_block, list):
+ raise ValueError(
+ "Must provide 'reverse_temporal_transformer_layers_per_block` if using asymmetrical motion module in UNet."
+ )
+
  # input
  conv_in_kernel = 3
  conv_out_kernel = 3
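The constructor now exposes temporal counterparts of the spatial transformer-depth arguments, and the new check above requires `reverse_temporal_transformer_layers_per_block` whenever a nested, asymmetric spec is passed. A small illustrative configuration (toy sizes, not a shipped model):

```python
from diffusers import UNetMotionModel

unet = UNetMotionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlockMotion", "DownBlockMotion"),
    up_block_types=("UpBlockMotion", "CrossAttnUpBlockMotion"),
    layers_per_block=2,
    cross_attention_dim=32,
    num_attention_heads=4,
    temporal_transformer_layers_per_block=(1, 2),  # per down block; reversed for the up path
    temporal_transformer_layers_per_mid_block=1,
    motion_num_attention_heads=8,
)
```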
@@ -304,6 +1454,20 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  if isinstance(transformer_layers_per_block, int):
  transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

+ if isinstance(reverse_transformer_layers_per_block, int):
+ reverse_transformer_layers_per_block = [reverse_transformer_layers_per_block] * len(down_block_types)
+
+ if isinstance(temporal_transformer_layers_per_block, int):
+ temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types)
+
+ if isinstance(reverse_temporal_transformer_layers_per_block, int):
+ reverse_temporal_transformer_layers_per_block = [reverse_temporal_transformer_layers_per_block] * len(
+ down_block_types
+ )
+
+ if isinstance(motion_num_attention_heads, int):
+ motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types)
+
  # down
  output_channel = block_out_channels[0]
  for i, down_block_type in enumerate(down_block_types):
@@ -311,28 +1475,53 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  output_channel = block_out_channels[i]
  is_final_block = i == len(block_out_channels) - 1

- down_block = get_down_block(
- down_block_type,
- num_layers=layers_per_block[i],
- in_channels=input_channel,
- out_channels=output_channel,
- temb_channels=time_embed_dim,
- add_downsample=not is_final_block,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- cross_attention_dim=cross_attention_dim[i],
- num_attention_heads=num_attention_heads[i],
- downsample_padding=downsample_padding,
- use_linear_projection=use_linear_projection,
- dual_cross_attention=False,
- temporal_num_attention_heads=motion_num_attention_heads,
- temporal_max_seq_length=motion_max_seq_length,
- transformer_layers_per_block=transformer_layers_per_block[i],
- )
+ if down_block_type == "CrossAttnDownBlockMotion":
+ down_block = CrossAttnDownBlockMotion(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ num_attention_heads=num_attention_heads[i],
+ cross_attention_dim=cross_attention_dim[i],
+ downsample_padding=downsample_padding,
+ add_downsample=not is_final_block,
+ use_linear_projection=use_linear_projection,
+ temporal_num_attention_heads=motion_num_attention_heads[i],
+ temporal_max_seq_length=motion_max_seq_length,
+ temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
+ )
+ elif down_block_type == "DownBlockMotion":
+ down_block = DownBlockMotion(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ num_layers=layers_per_block[i],
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ add_downsample=not is_final_block,
+ downsample_padding=downsample_padding,
+ temporal_num_attention_heads=motion_num_attention_heads[i],
+ temporal_max_seq_length=motion_max_seq_length,
+ temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
+ )
+ else:
+ raise ValueError(
+ "Invalid `down_block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`"
+ )
+
  self.down_blocks.append(down_block)

  # mid
+ if transformer_layers_per_mid_block is None:
+ transformer_layers_per_mid_block = (
+ transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1
+ )
+
  if use_motion_mid_block:
  self.mid_block = UNetMidBlockCrossAttnMotion(
  in_channels=block_out_channels[-1],
@@ -345,9 +1534,11 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  resnet_groups=norm_num_groups,
  dual_cross_attention=False,
  use_linear_projection=use_linear_projection,
- temporal_num_attention_heads=motion_num_attention_heads,
+ num_layers=mid_block_layers,
+ temporal_num_attention_heads=motion_num_attention_heads[-1],
  temporal_max_seq_length=motion_max_seq_length,
- transformer_layers_per_block=transformer_layers_per_block[-1],
+ transformer_layers_per_block=transformer_layers_per_mid_block,
+ temporal_transformer_layers_per_block=temporal_transformer_layers_per_mid_block,
  )

  else:
@@ -362,7 +1553,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  resnet_groups=norm_num_groups,
  dual_cross_attention=False,
  use_linear_projection=use_linear_projection,
- transformer_layers_per_block=transformer_layers_per_block[-1],
+ num_layers=mid_block_layers,
+ transformer_layers_per_block=transformer_layers_per_mid_block,
  )

  # count how many layers upsample the images
@@ -373,7 +1565,13 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  reversed_num_attention_heads = list(reversed(num_attention_heads))
  reversed_layers_per_block = list(reversed(layers_per_block))
  reversed_cross_attention_dim = list(reversed(cross_attention_dim))
- reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+ reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads))
+
+ if reverse_transformer_layers_per_block is None:
+ reverse_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+
+ if reverse_temporal_transformer_layers_per_block is None:
+ reverse_temporal_transformer_layers_per_block = list(reversed(temporal_transformer_layers_per_block))

  output_channel = reversed_block_out_channels[0]
  for i, up_block_type in enumerate(up_block_types):
@@ -390,26 +1588,47 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  else:
  add_upsample = False

- up_block = get_up_block(
- up_block_type,
- num_layers=reversed_layers_per_block[i] + 1,
- in_channels=input_channel,
- out_channels=output_channel,
- prev_output_channel=prev_output_channel,
- temb_channels=time_embed_dim,
- add_upsample=add_upsample,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- cross_attention_dim=reversed_cross_attention_dim[i],
- num_attention_heads=reversed_num_attention_heads[i],
- dual_cross_attention=False,
- resolution_idx=i,
- use_linear_projection=use_linear_projection,
- temporal_num_attention_heads=motion_num_attention_heads,
- temporal_max_seq_length=motion_max_seq_length,
- transformer_layers_per_block=reversed_transformer_layers_per_block[i],
- )
+ if up_block_type == "CrossAttnUpBlockMotion":
+ up_block = CrossAttnUpBlockMotion(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ resolution_idx=i,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reverse_transformer_layers_per_block[i],
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ num_attention_heads=reversed_num_attention_heads[i],
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ add_upsample=add_upsample,
+ use_linear_projection=use_linear_projection,
+ temporal_num_attention_heads=reversed_motion_num_attention_heads[i],
+ temporal_max_seq_length=motion_max_seq_length,
+ temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i],
+ )
+ elif up_block_type == "UpBlockMotion":
+ up_block = UpBlockMotion(
+ in_channels=input_channel,
+ prev_output_channel=prev_output_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ resolution_idx=i,
+ num_layers=reversed_layers_per_block[i] + 1,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ add_upsample=add_upsample,
+ temporal_num_attention_heads=reversed_motion_num_attention_heads[i],
+ temporal_max_seq_length=motion_max_seq_length,
+ temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i],
+ )
+ else:
+ raise ValueError(
+ "Invalid `up_block_type` encountered. Must be one of `CrossAttnUpBlockMotion` or `UpBlockMotion`"
+ )
+
  self.up_blocks.append(up_block)
  prev_output_channel = output_channel

@@ -440,6 +1659,24 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  if has_motion_adapter:
  motion_adapter.to(device=unet.device)

+ # check compatibility of number of blocks
+ if len(unet.config["down_block_types"]) != len(motion_adapter.config["block_out_channels"]):
+ raise ValueError("Incompatible Motion Adapter, got different number of blocks")
+
+ # check layers compatibility for each block
+ if isinstance(unet.config["layers_per_block"], int):
+ expanded_layers_per_block = [unet.config["layers_per_block"]] * len(unet.config["down_block_types"])
+ else:
+ expanded_layers_per_block = list(unet.config["layers_per_block"])
+ if isinstance(motion_adapter.config["motion_layers_per_block"], int):
+ expanded_adapter_layers_per_block = [motion_adapter.config["motion_layers_per_block"]] * len(
+ motion_adapter.config["block_out_channels"]
+ )
+ else:
+ expanded_adapter_layers_per_block = list(motion_adapter.config["motion_layers_per_block"])
+ if expanded_layers_per_block != expanded_adapter_layers_per_block:
+ raise ValueError("Incompatible Motion Adapter, got different number of layers per block")
+
  # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
  config = dict(unet.config)
  config["_class_name"] = cls.__name__
@@ -458,13 +1695,20 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  up_blocks.append("CrossAttnUpBlockMotion")
  else:
  up_blocks.append("UpBlockMotion")
-
  config["up_block_types"] = up_blocks

  if has_motion_adapter:
  config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]
  config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"]
  config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"]
+ config["layers_per_block"] = motion_adapter.config["motion_layers_per_block"]
+ config["temporal_transformer_layers_per_mid_block"] = motion_adapter.config[
+ "motion_transformer_layers_per_mid_block"
+ ]
+ config["temporal_transformer_layers_per_block"] = motion_adapter.config[
+ "motion_transformer_layers_per_block"
+ ]
+ config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]

  # For PIA UNets we need to set the number input channels to 9
  if motion_adapter.config["conv_in_channels"]:
@@ -474,7 +1718,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  if not config.get("num_attention_heads"):
  config["num_attention_heads"] = config["attention_head_dim"]

- config = FrozenDict(config)
+ expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
+ config = FrozenDict({k: config.get(k) for k in config if k in expected_kwargs or k in optional_kwargs})
+ config["_class_name"] = cls.__name__
  model = cls.from_config(config)

  if not load_weights:
@@ -637,7 +1883,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):

  def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
  if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+ processors[f"{name}.processor"] = module.get_processor()

  for sub_name, child in module.named_children():
  fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -684,7 +1930,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  for name, module in self.named_children():
  fn_recursive_attn_processor(name, module, processor)

- # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
  def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
  """
  Sets the attention processor to use [feed forward
@@ -714,7 +1959,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  for module in self.children():
  fn_recursive_feed_forward(module, chunk_size, dim)

- # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
  def disable_forward_chunking(self) -> None:
  def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
  if hasattr(module, "set_chunk_feed_forward"):
@@ -804,6 +2048,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  if isinstance(module, Attention):
  module.fuse_projections(fuse=True)

+ self.set_attn_processor(FusedAttnProcessor2_0())
+
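The only functional change in this hunk is that fusing now also installs `FusedAttnProcessor2_0` on the model. A hedged usage sketch, reusing the toy `unet` from the configuration example earlier:

```python
unet.fuse_qkv_projections()    # fuses Q/K/V weights and now sets FusedAttnProcessor2_0
# ... run inference with the fused projections ...
unet.unfuse_qkv_projections()  # restores the original, unfused attention processors
```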
  # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
  def unfuse_qkv_projections(self):
  """Disables the fused QKV projection if enabled.
@@ -830,7 +2076,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
  mid_block_additional_residual: Optional[torch.Tensor] = None,
  return_dict: bool = True,
- ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]:
+ ) -> Union[UNetMotionOutput, Tuple[torch.Tensor]]:
  r"""
  The [`UNetMotionModel`] forward method.

@@ -856,12 +2102,12 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  mid_block_additional_residual: (`torch.Tensor`, *optional*):
  A tensor that if specified is added to the residual of the middle unet block.
  return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
+ Whether or not to return a [`~models.unets.unet_motion_model.UNetMotionOutput`] instead of a plain
  tuple.

  Returns:
- [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
- If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned,
+ [`~models.unets.unet_motion_model.UNetMotionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unets.unet_motion_model.UNetMotionOutput`] is returned,
  otherwise a `tuple` is returned where the first element is the sample tensor.
  """
  # By default samples have to be AT least a multiple of the overall upsampling factor.
@@ -1045,4 +2291,4 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
  if not return_dict:
  return (sample,)

- return UNet3DConditionOutput(sample=sample)
+ return UNetMotionOutput(sample=sample)
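Downstream code that consumed `UNet3DConditionOutput` from this forward pass should now read the same `.sample` field from `UNetMotionOutput`. A hedged end-to-end sketch reusing the toy `unet` from the earlier configuration example (shapes are illustrative):

```python
import torch

sample = torch.randn(1, 4, 8, 32, 32)           # (batch, channels, num_frames, height, width)
encoder_hidden_states = torch.randn(1, 77, 32)  # (batch, sequence_length, cross_attention_dim)

out = unet(sample, timestep=1, encoder_hidden_states=encoder_hidden_states)
noise_pred = out.sample  # UNetMotionOutput, not UNet3DConditionOutput
# equivalently, with return_dict=False:
noise_pred = unet(sample, 1, encoder_hidden_states=encoder_hidden_states, return_dict=False)[0]
```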