diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2252 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +3 -14
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +293 -8
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1937 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +403 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +50 -6
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +37 -15
- diffusers/utils/loading_utils.py +80 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
@@ -11,6 +11,8 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
+
|
15
|
+
from dataclasses import dataclass
|
14
16
|
from typing import Any, Dict, Optional, Tuple, Union
|
15
17
|
|
16
18
|
import torch
|
@@ -19,8 +21,10 @@ import torch.nn.functional as F
|
|
19
21
|
import torch.utils.checkpoint
|
20
22
|
|
21
23
|
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
22
|
-
from ...loaders import UNet2DConditionLoadersMixin
|
23
|
-
from ...utils import logging
|
24
|
+
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
|
25
|
+
from ...utils import BaseOutput, deprecate, is_torch_version, logging
|
26
|
+
from ...utils.torch_utils import apply_freeu
|
27
|
+
from ..attention import BasicTransformerBlock
|
24
28
|
from ..attention_processor import (
|
25
29
|
ADDED_KV_ATTENTION_PROCESSORS,
|
26
30
|
CROSS_ATTENTION_PROCESSORS,
|
@@ -29,35 +33,1114 @@ from ..attention_processor import (
|
|
29
33
|
AttnAddedKVProcessor,
|
30
34
|
AttnProcessor,
|
31
35
|
AttnProcessor2_0,
|
36
|
+
FusedAttnProcessor2_0,
|
32
37
|
IPAdapterAttnProcessor,
|
33
38
|
IPAdapterAttnProcessor2_0,
|
34
39
|
)
|
35
40
|
from ..embeddings import TimestepEmbedding, Timesteps
|
36
41
|
from ..modeling_utils import ModelMixin
|
37
|
-
from ..
|
42
|
+
from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
43
|
+
from ..transformers.dual_transformer_2d import DualTransformer2DModel
|
44
|
+
from ..transformers.transformer_2d import Transformer2DModel
|
38
45
|
from .unet_2d_blocks import UNetMidBlock2DCrossAttn
|
39
46
|
from .unet_2d_condition import UNet2DConditionModel
|
40
|
-
from .unet_3d_blocks import (
|
41
|
-
CrossAttnDownBlockMotion,
|
42
|
-
CrossAttnUpBlockMotion,
|
43
|
-
DownBlockMotion,
|
44
|
-
UNetMidBlockCrossAttnMotion,
|
45
|
-
UpBlockMotion,
|
46
|
-
get_down_block,
|
47
|
-
get_up_block,
|
48
|
-
)
|
49
|
-
from .unet_3d_condition import UNet3DConditionOutput
|
50
47
|
|
51
48
|
|
52
49
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
53
50
|
|
54
51
|
|
52
|
+
@dataclass
|
53
|
+
class UNetMotionOutput(BaseOutput):
|
54
|
+
"""
|
55
|
+
The output of [`UNetMotionOutput`].
|
56
|
+
|
57
|
+
Args:
|
58
|
+
sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
|
59
|
+
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
60
|
+
"""
|
61
|
+
|
62
|
+
sample: torch.Tensor
|
63
|
+
|
64
|
+
|
65
|
+
class AnimateDiffTransformer3D(nn.Module):
|
66
|
+
"""
|
67
|
+
A Transformer model for video-like data.
|
68
|
+
|
69
|
+
Parameters:
|
70
|
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
71
|
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
72
|
+
in_channels (`int`, *optional*):
|
73
|
+
The number of channels in the input and output (specify if the input is **continuous**).
|
74
|
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
75
|
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
76
|
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
77
|
+
attention_bias (`bool`, *optional*):
|
78
|
+
Configure if the `TransformerBlock` attention should contain a bias parameter.
|
79
|
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
80
|
+
This is fixed during training since it is used to learn a number of position embeddings.
|
81
|
+
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
82
|
+
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
|
83
|
+
activation functions.
|
84
|
+
norm_elementwise_affine (`bool`, *optional*):
|
85
|
+
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
|
86
|
+
double_self_attention (`bool`, *optional*):
|
87
|
+
Configure if each `TransformerBlock` should contain two self-attention layers.
|
88
|
+
positional_embeddings: (`str`, *optional*):
|
89
|
+
The type of positional embeddings to apply to the sequence input before passing use.
|
90
|
+
num_positional_embeddings: (`int`, *optional*):
|
91
|
+
The maximum length of the sequence over which to apply positional embeddings.
|
92
|
+
"""
|
93
|
+
|
94
|
+
def __init__(
|
95
|
+
self,
|
96
|
+
num_attention_heads: int = 16,
|
97
|
+
attention_head_dim: int = 88,
|
98
|
+
in_channels: Optional[int] = None,
|
99
|
+
out_channels: Optional[int] = None,
|
100
|
+
num_layers: int = 1,
|
101
|
+
dropout: float = 0.0,
|
102
|
+
norm_num_groups: int = 32,
|
103
|
+
cross_attention_dim: Optional[int] = None,
|
104
|
+
attention_bias: bool = False,
|
105
|
+
sample_size: Optional[int] = None,
|
106
|
+
activation_fn: str = "geglu",
|
107
|
+
norm_elementwise_affine: bool = True,
|
108
|
+
double_self_attention: bool = True,
|
109
|
+
positional_embeddings: Optional[str] = None,
|
110
|
+
num_positional_embeddings: Optional[int] = None,
|
111
|
+
):
|
112
|
+
super().__init__()
|
113
|
+
self.num_attention_heads = num_attention_heads
|
114
|
+
self.attention_head_dim = attention_head_dim
|
115
|
+
inner_dim = num_attention_heads * attention_head_dim
|
116
|
+
|
117
|
+
self.in_channels = in_channels
|
118
|
+
|
119
|
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
120
|
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
121
|
+
|
122
|
+
# 3. Define transformers blocks
|
123
|
+
self.transformer_blocks = nn.ModuleList(
|
124
|
+
[
|
125
|
+
BasicTransformerBlock(
|
126
|
+
inner_dim,
|
127
|
+
num_attention_heads,
|
128
|
+
attention_head_dim,
|
129
|
+
dropout=dropout,
|
130
|
+
cross_attention_dim=cross_attention_dim,
|
131
|
+
activation_fn=activation_fn,
|
132
|
+
attention_bias=attention_bias,
|
133
|
+
double_self_attention=double_self_attention,
|
134
|
+
norm_elementwise_affine=norm_elementwise_affine,
|
135
|
+
positional_embeddings=positional_embeddings,
|
136
|
+
num_positional_embeddings=num_positional_embeddings,
|
137
|
+
)
|
138
|
+
for _ in range(num_layers)
|
139
|
+
]
|
140
|
+
)
|
141
|
+
|
142
|
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
143
|
+
|
144
|
+
def forward(
|
145
|
+
self,
|
146
|
+
hidden_states: torch.Tensor,
|
147
|
+
encoder_hidden_states: Optional[torch.LongTensor] = None,
|
148
|
+
timestep: Optional[torch.LongTensor] = None,
|
149
|
+
class_labels: Optional[torch.LongTensor] = None,
|
150
|
+
num_frames: int = 1,
|
151
|
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
152
|
+
) -> torch.Tensor:
|
153
|
+
"""
|
154
|
+
The [`AnimateDiffTransformer3D`] forward method.
|
155
|
+
|
156
|
+
Args:
|
157
|
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
|
158
|
+
Input hidden_states.
|
159
|
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
160
|
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
161
|
+
self-attention.
|
162
|
+
timestep ( `torch.LongTensor`, *optional*):
|
163
|
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
164
|
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
165
|
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
166
|
+
`AdaLayerZeroNorm`.
|
167
|
+
num_frames (`int`, *optional*, defaults to 1):
|
168
|
+
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
169
|
+
cross_attention_kwargs (`dict`, *optional*):
|
170
|
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
171
|
+
`self.processor` in
|
172
|
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
173
|
+
|
174
|
+
Returns:
|
175
|
+
torch.Tensor:
|
176
|
+
The output tensor.
|
177
|
+
"""
|
178
|
+
# 1. Input
|
179
|
+
batch_frames, channel, height, width = hidden_states.shape
|
180
|
+
batch_size = batch_frames // num_frames
|
181
|
+
|
182
|
+
residual = hidden_states
|
183
|
+
|
184
|
+
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
|
185
|
+
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
186
|
+
|
187
|
+
hidden_states = self.norm(hidden_states)
|
188
|
+
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
|
189
|
+
|
190
|
+
hidden_states = self.proj_in(hidden_states)
|
191
|
+
|
192
|
+
# 2. Blocks
|
193
|
+
for block in self.transformer_blocks:
|
194
|
+
hidden_states = block(
|
195
|
+
hidden_states,
|
196
|
+
encoder_hidden_states=encoder_hidden_states,
|
197
|
+
timestep=timestep,
|
198
|
+
cross_attention_kwargs=cross_attention_kwargs,
|
199
|
+
class_labels=class_labels,
|
200
|
+
)
|
201
|
+
|
202
|
+
# 3. Output
|
203
|
+
hidden_states = self.proj_out(hidden_states)
|
204
|
+
hidden_states = (
|
205
|
+
hidden_states[None, None, :]
|
206
|
+
.reshape(batch_size, height, width, num_frames, channel)
|
207
|
+
.permute(0, 3, 4, 1, 2)
|
208
|
+
.contiguous()
|
209
|
+
)
|
210
|
+
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
|
211
|
+
|
212
|
+
output = hidden_states + residual
|
213
|
+
return output
|
214
|
+
|
215
|
+
|
216
|
+
class DownBlockMotion(nn.Module):
|
217
|
+
def __init__(
|
218
|
+
self,
|
219
|
+
in_channels: int,
|
220
|
+
out_channels: int,
|
221
|
+
temb_channels: int,
|
222
|
+
dropout: float = 0.0,
|
223
|
+
num_layers: int = 1,
|
224
|
+
resnet_eps: float = 1e-6,
|
225
|
+
resnet_time_scale_shift: str = "default",
|
226
|
+
resnet_act_fn: str = "swish",
|
227
|
+
resnet_groups: int = 32,
|
228
|
+
resnet_pre_norm: bool = True,
|
229
|
+
output_scale_factor: float = 1.0,
|
230
|
+
add_downsample: bool = True,
|
231
|
+
downsample_padding: int = 1,
|
232
|
+
temporal_num_attention_heads: Union[int, Tuple[int]] = 1,
|
233
|
+
temporal_cross_attention_dim: Optional[int] = None,
|
234
|
+
temporal_max_seq_length: int = 32,
|
235
|
+
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
236
|
+
temporal_double_self_attention: bool = True,
|
237
|
+
):
|
238
|
+
super().__init__()
|
239
|
+
resnets = []
|
240
|
+
motion_modules = []
|
241
|
+
|
242
|
+
# support for variable transformer layers per temporal block
|
243
|
+
if isinstance(temporal_transformer_layers_per_block, int):
|
244
|
+
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
|
245
|
+
elif len(temporal_transformer_layers_per_block) != num_layers:
|
246
|
+
raise ValueError(
|
247
|
+
f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}"
|
248
|
+
)
|
249
|
+
|
250
|
+
# support for variable number of attention head per temporal layers
|
251
|
+
if isinstance(temporal_num_attention_heads, int):
|
252
|
+
temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers
|
253
|
+
elif len(temporal_num_attention_heads) != num_layers:
|
254
|
+
raise ValueError(
|
255
|
+
f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}"
|
256
|
+
)
|
257
|
+
|
258
|
+
for i in range(num_layers):
|
259
|
+
in_channels = in_channels if i == 0 else out_channels
|
260
|
+
resnets.append(
|
261
|
+
ResnetBlock2D(
|
262
|
+
in_channels=in_channels,
|
263
|
+
out_channels=out_channels,
|
264
|
+
temb_channels=temb_channels,
|
265
|
+
eps=resnet_eps,
|
266
|
+
groups=resnet_groups,
|
267
|
+
dropout=dropout,
|
268
|
+
time_embedding_norm=resnet_time_scale_shift,
|
269
|
+
non_linearity=resnet_act_fn,
|
270
|
+
output_scale_factor=output_scale_factor,
|
271
|
+
pre_norm=resnet_pre_norm,
|
272
|
+
)
|
273
|
+
)
|
274
|
+
motion_modules.append(
|
275
|
+
AnimateDiffTransformer3D(
|
276
|
+
num_attention_heads=temporal_num_attention_heads[i],
|
277
|
+
in_channels=out_channels,
|
278
|
+
num_layers=temporal_transformer_layers_per_block[i],
|
279
|
+
norm_num_groups=resnet_groups,
|
280
|
+
cross_attention_dim=temporal_cross_attention_dim,
|
281
|
+
attention_bias=False,
|
282
|
+
activation_fn="geglu",
|
283
|
+
positional_embeddings="sinusoidal",
|
284
|
+
num_positional_embeddings=temporal_max_seq_length,
|
285
|
+
attention_head_dim=out_channels // temporal_num_attention_heads[i],
|
286
|
+
double_self_attention=temporal_double_self_attention,
|
287
|
+
)
|
288
|
+
)
|
289
|
+
|
290
|
+
self.resnets = nn.ModuleList(resnets)
|
291
|
+
self.motion_modules = nn.ModuleList(motion_modules)
|
292
|
+
|
293
|
+
if add_downsample:
|
294
|
+
self.downsamplers = nn.ModuleList(
|
295
|
+
[
|
296
|
+
Downsample2D(
|
297
|
+
out_channels,
|
298
|
+
use_conv=True,
|
299
|
+
out_channels=out_channels,
|
300
|
+
padding=downsample_padding,
|
301
|
+
name="op",
|
302
|
+
)
|
303
|
+
]
|
304
|
+
)
|
305
|
+
else:
|
306
|
+
self.downsamplers = None
|
307
|
+
|
308
|
+
self.gradient_checkpointing = False
|
309
|
+
|
310
|
+
def forward(
|
311
|
+
self,
|
312
|
+
hidden_states: torch.Tensor,
|
313
|
+
temb: Optional[torch.Tensor] = None,
|
314
|
+
num_frames: int = 1,
|
315
|
+
*args,
|
316
|
+
**kwargs,
|
317
|
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
|
318
|
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
319
|
+
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
320
|
+
deprecate("scale", "1.0.0", deprecation_message)
|
321
|
+
|
322
|
+
output_states = ()
|
323
|
+
|
324
|
+
blocks = zip(self.resnets, self.motion_modules)
|
325
|
+
for resnet, motion_module in blocks:
|
326
|
+
if self.training and self.gradient_checkpointing:
|
327
|
+
|
328
|
+
def create_custom_forward(module):
|
329
|
+
def custom_forward(*inputs):
|
330
|
+
return module(*inputs)
|
331
|
+
|
332
|
+
return custom_forward
|
333
|
+
|
334
|
+
if is_torch_version(">=", "1.11.0"):
|
335
|
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
336
|
+
create_custom_forward(resnet),
|
337
|
+
hidden_states,
|
338
|
+
temb,
|
339
|
+
use_reentrant=False,
|
340
|
+
)
|
341
|
+
else:
|
342
|
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
343
|
+
create_custom_forward(resnet), hidden_states, temb
|
344
|
+
)
|
345
|
+
|
346
|
+
else:
|
347
|
+
hidden_states = resnet(hidden_states, temb)
|
348
|
+
|
349
|
+
hidden_states = motion_module(hidden_states, num_frames=num_frames)
|
350
|
+
|
351
|
+
output_states = output_states + (hidden_states,)
|
352
|
+
|
353
|
+
if self.downsamplers is not None:
|
354
|
+
for downsampler in self.downsamplers:
|
355
|
+
hidden_states = downsampler(hidden_states)
|
356
|
+
|
357
|
+
output_states = output_states + (hidden_states,)
|
358
|
+
|
359
|
+
return hidden_states, output_states
|
360
|
+
|
361
|
+
|
362
|
+
class CrossAttnDownBlockMotion(nn.Module):
|
363
|
+
def __init__(
|
364
|
+
self,
|
365
|
+
in_channels: int,
|
366
|
+
out_channels: int,
|
367
|
+
temb_channels: int,
|
368
|
+
dropout: float = 0.0,
|
369
|
+
num_layers: int = 1,
|
370
|
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
371
|
+
resnet_eps: float = 1e-6,
|
372
|
+
resnet_time_scale_shift: str = "default",
|
373
|
+
resnet_act_fn: str = "swish",
|
374
|
+
resnet_groups: int = 32,
|
375
|
+
resnet_pre_norm: bool = True,
|
376
|
+
num_attention_heads: int = 1,
|
377
|
+
cross_attention_dim: int = 1280,
|
378
|
+
output_scale_factor: float = 1.0,
|
379
|
+
downsample_padding: int = 1,
|
380
|
+
add_downsample: bool = True,
|
381
|
+
dual_cross_attention: bool = False,
|
382
|
+
use_linear_projection: bool = False,
|
383
|
+
only_cross_attention: bool = False,
|
384
|
+
upcast_attention: bool = False,
|
385
|
+
attention_type: str = "default",
|
386
|
+
temporal_cross_attention_dim: Optional[int] = None,
|
387
|
+
temporal_num_attention_heads: int = 8,
|
388
|
+
temporal_max_seq_length: int = 32,
|
389
|
+
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
390
|
+
temporal_double_self_attention: bool = True,
|
391
|
+
):
|
392
|
+
super().__init__()
|
393
|
+
resnets = []
|
394
|
+
attentions = []
|
395
|
+
motion_modules = []
|
396
|
+
|
397
|
+
self.has_cross_attention = True
|
398
|
+
self.num_attention_heads = num_attention_heads
|
399
|
+
|
400
|
+
# support for variable transformer layers per block
|
401
|
+
if isinstance(transformer_layers_per_block, int):
|
402
|
+
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
|
403
|
+
elif len(transformer_layers_per_block) != num_layers:
|
404
|
+
raise ValueError(
|
405
|
+
f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
|
406
|
+
)
|
407
|
+
|
408
|
+
# support for variable transformer layers per temporal block
|
409
|
+
if isinstance(temporal_transformer_layers_per_block, int):
|
410
|
+
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
|
411
|
+
elif len(temporal_transformer_layers_per_block) != num_layers:
|
412
|
+
raise ValueError(
|
413
|
+
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
|
414
|
+
)
|
415
|
+
|
416
|
+
for i in range(num_layers):
|
417
|
+
in_channels = in_channels if i == 0 else out_channels
|
418
|
+
resnets.append(
|
419
|
+
ResnetBlock2D(
|
420
|
+
in_channels=in_channels,
|
421
|
+
out_channels=out_channels,
|
422
|
+
temb_channels=temb_channels,
|
423
|
+
eps=resnet_eps,
|
424
|
+
groups=resnet_groups,
|
425
|
+
dropout=dropout,
|
426
|
+
time_embedding_norm=resnet_time_scale_shift,
|
427
|
+
non_linearity=resnet_act_fn,
|
428
|
+
output_scale_factor=output_scale_factor,
|
429
|
+
pre_norm=resnet_pre_norm,
|
430
|
+
)
|
431
|
+
)
|
432
|
+
|
433
|
+
if not dual_cross_attention:
|
434
|
+
attentions.append(
|
435
|
+
Transformer2DModel(
|
436
|
+
num_attention_heads,
|
437
|
+
out_channels // num_attention_heads,
|
438
|
+
in_channels=out_channels,
|
439
|
+
num_layers=transformer_layers_per_block[i],
|
440
|
+
cross_attention_dim=cross_attention_dim,
|
441
|
+
norm_num_groups=resnet_groups,
|
442
|
+
use_linear_projection=use_linear_projection,
|
443
|
+
only_cross_attention=only_cross_attention,
|
444
|
+
upcast_attention=upcast_attention,
|
445
|
+
attention_type=attention_type,
|
446
|
+
)
|
447
|
+
)
|
448
|
+
else:
|
449
|
+
attentions.append(
|
450
|
+
DualTransformer2DModel(
|
451
|
+
num_attention_heads,
|
452
|
+
out_channels // num_attention_heads,
|
453
|
+
in_channels=out_channels,
|
454
|
+
num_layers=1,
|
455
|
+
cross_attention_dim=cross_attention_dim,
|
456
|
+
norm_num_groups=resnet_groups,
|
457
|
+
)
|
458
|
+
)
|
459
|
+
|
460
|
+
motion_modules.append(
|
461
|
+
AnimateDiffTransformer3D(
|
462
|
+
num_attention_heads=temporal_num_attention_heads,
|
463
|
+
in_channels=out_channels,
|
464
|
+
num_layers=temporal_transformer_layers_per_block[i],
|
465
|
+
norm_num_groups=resnet_groups,
|
466
|
+
cross_attention_dim=temporal_cross_attention_dim,
|
467
|
+
attention_bias=False,
|
468
|
+
activation_fn="geglu",
|
469
|
+
positional_embeddings="sinusoidal",
|
470
|
+
num_positional_embeddings=temporal_max_seq_length,
|
471
|
+
attention_head_dim=out_channels // temporal_num_attention_heads,
|
472
|
+
double_self_attention=temporal_double_self_attention,
|
473
|
+
)
|
474
|
+
)
|
475
|
+
|
476
|
+
self.attentions = nn.ModuleList(attentions)
|
477
|
+
self.resnets = nn.ModuleList(resnets)
|
478
|
+
self.motion_modules = nn.ModuleList(motion_modules)
|
479
|
+
|
480
|
+
if add_downsample:
|
481
|
+
self.downsamplers = nn.ModuleList(
|
482
|
+
[
|
483
|
+
Downsample2D(
|
484
|
+
out_channels,
|
485
|
+
use_conv=True,
|
486
|
+
out_channels=out_channels,
|
487
|
+
padding=downsample_padding,
|
488
|
+
name="op",
|
489
|
+
)
|
490
|
+
]
|
491
|
+
)
|
492
|
+
else:
|
493
|
+
self.downsamplers = None
|
494
|
+
|
495
|
+
self.gradient_checkpointing = False
|
496
|
+
|
497
|
+
def forward(
|
498
|
+
self,
|
499
|
+
hidden_states: torch.Tensor,
|
500
|
+
temb: Optional[torch.Tensor] = None,
|
501
|
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
502
|
+
attention_mask: Optional[torch.Tensor] = None,
|
503
|
+
num_frames: int = 1,
|
504
|
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
505
|
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
506
|
+
additional_residuals: Optional[torch.Tensor] = None,
|
507
|
+
):
|
508
|
+
if cross_attention_kwargs is not None:
|
509
|
+
if cross_attention_kwargs.get("scale", None) is not None:
|
510
|
+
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
511
|
+
|
512
|
+
output_states = ()
|
513
|
+
|
514
|
+
blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
|
515
|
+
for i, (resnet, attn, motion_module) in enumerate(blocks):
|
516
|
+
if self.training and self.gradient_checkpointing:
|
517
|
+
|
518
|
+
def create_custom_forward(module, return_dict=None):
|
519
|
+
def custom_forward(*inputs):
|
520
|
+
if return_dict is not None:
|
521
|
+
return module(*inputs, return_dict=return_dict)
|
522
|
+
else:
|
523
|
+
return module(*inputs)
|
524
|
+
|
525
|
+
return custom_forward
|
526
|
+
|
527
|
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
528
|
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
529
|
+
create_custom_forward(resnet),
|
530
|
+
hidden_states,
|
531
|
+
temb,
|
532
|
+
**ckpt_kwargs,
|
533
|
+
)
|
534
|
+
hidden_states = attn(
|
535
|
+
hidden_states,
|
536
|
+
encoder_hidden_states=encoder_hidden_states,
|
537
|
+
cross_attention_kwargs=cross_attention_kwargs,
|
538
|
+
attention_mask=attention_mask,
|
539
|
+
encoder_attention_mask=encoder_attention_mask,
|
540
|
+
return_dict=False,
|
541
|
+
)[0]
|
542
|
+
else:
|
543
|
+
hidden_states = resnet(hidden_states, temb)
|
544
|
+
|
545
|
+
hidden_states = attn(
|
546
|
+
hidden_states,
|
547
|
+
encoder_hidden_states=encoder_hidden_states,
|
548
|
+
cross_attention_kwargs=cross_attention_kwargs,
|
549
|
+
attention_mask=attention_mask,
|
550
|
+
encoder_attention_mask=encoder_attention_mask,
|
551
|
+
return_dict=False,
|
552
|
+
)[0]
|
553
|
+
hidden_states = motion_module(
|
554
|
+
hidden_states,
|
555
|
+
num_frames=num_frames,
|
556
|
+
)
|
557
|
+
|
558
|
+
# apply additional residuals to the output of the last pair of resnet and attention blocks
|
559
|
+
if i == len(blocks) - 1 and additional_residuals is not None:
|
560
|
+
hidden_states = hidden_states + additional_residuals
|
561
|
+
|
562
|
+
output_states = output_states + (hidden_states,)
|
563
|
+
|
564
|
+
if self.downsamplers is not None:
|
565
|
+
for downsampler in self.downsamplers:
|
566
|
+
hidden_states = downsampler(hidden_states)
|
567
|
+
|
568
|
+
output_states = output_states + (hidden_states,)
|
569
|
+
|
570
|
+
return hidden_states, output_states
|
571
|
+
|
572
|
+
|
573
|
+
class CrossAttnUpBlockMotion(nn.Module):
|
574
|
+
def __init__(
|
575
|
+
self,
|
576
|
+
in_channels: int,
|
577
|
+
out_channels: int,
|
578
|
+
prev_output_channel: int,
|
579
|
+
temb_channels: int,
|
580
|
+
resolution_idx: Optional[int] = None,
|
581
|
+
dropout: float = 0.0,
|
582
|
+
num_layers: int = 1,
|
583
|
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
584
|
+
resnet_eps: float = 1e-6,
|
585
|
+
resnet_time_scale_shift: str = "default",
|
586
|
+
resnet_act_fn: str = "swish",
|
587
|
+
resnet_groups: int = 32,
|
588
|
+
resnet_pre_norm: bool = True,
|
589
|
+
num_attention_heads: int = 1,
|
590
|
+
cross_attention_dim: int = 1280,
|
591
|
+
output_scale_factor: float = 1.0,
|
592
|
+
add_upsample: bool = True,
|
593
|
+
dual_cross_attention: bool = False,
|
594
|
+
use_linear_projection: bool = False,
|
595
|
+
only_cross_attention: bool = False,
|
596
|
+
upcast_attention: bool = False,
|
597
|
+
attention_type: str = "default",
|
598
|
+
temporal_cross_attention_dim: Optional[int] = None,
|
599
|
+
temporal_num_attention_heads: int = 8,
|
600
|
+
temporal_max_seq_length: int = 32,
|
601
|
+
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
602
|
+
):
|
603
|
+
super().__init__()
|
604
|
+
resnets = []
|
605
|
+
attentions = []
|
606
|
+
motion_modules = []
|
607
|
+
|
608
|
+
self.has_cross_attention = True
|
609
|
+
self.num_attention_heads = num_attention_heads
|
610
|
+
|
611
|
+
# support for variable transformer layers per block
|
612
|
+
if isinstance(transformer_layers_per_block, int):
|
613
|
+
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
|
614
|
+
elif len(transformer_layers_per_block) != num_layers:
|
615
|
+
raise ValueError(
|
616
|
+
f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}"
|
617
|
+
)
|
618
|
+
|
619
|
+
# support for variable transformer layers per temporal block
|
620
|
+
if isinstance(temporal_transformer_layers_per_block, int):
|
621
|
+
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
|
622
|
+
elif len(temporal_transformer_layers_per_block) != num_layers:
|
623
|
+
raise ValueError(
|
624
|
+
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}"
|
625
|
+
)
|
626
|
+
|
627
|
+
for i in range(num_layers):
|
628
|
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
629
|
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
630
|
+
|
631
|
+
resnets.append(
|
632
|
+
ResnetBlock2D(
|
633
|
+
in_channels=resnet_in_channels + res_skip_channels,
|
634
|
+
out_channels=out_channels,
|
635
|
+
temb_channels=temb_channels,
|
636
|
+
eps=resnet_eps,
|
637
|
+
groups=resnet_groups,
|
638
|
+
dropout=dropout,
|
639
|
+
time_embedding_norm=resnet_time_scale_shift,
|
640
|
+
non_linearity=resnet_act_fn,
|
641
|
+
output_scale_factor=output_scale_factor,
|
642
|
+
pre_norm=resnet_pre_norm,
|
643
|
+
)
|
644
|
+
)
|
645
|
+
|
646
|
+
if not dual_cross_attention:
|
647
|
+
attentions.append(
|
648
|
+
Transformer2DModel(
|
649
|
+
num_attention_heads,
|
650
|
+
out_channels // num_attention_heads,
|
651
|
+
in_channels=out_channels,
|
652
|
+
num_layers=transformer_layers_per_block[i],
|
653
|
+
cross_attention_dim=cross_attention_dim,
|
654
|
+
norm_num_groups=resnet_groups,
|
655
|
+
use_linear_projection=use_linear_projection,
|
656
|
+
only_cross_attention=only_cross_attention,
|
657
|
+
upcast_attention=upcast_attention,
|
658
|
+
attention_type=attention_type,
|
659
|
+
)
|
660
|
+
)
|
661
|
+
else:
|
662
|
+
attentions.append(
|
663
|
+
DualTransformer2DModel(
|
664
|
+
num_attention_heads,
|
665
|
+
out_channels // num_attention_heads,
|
666
|
+
in_channels=out_channels,
|
667
|
+
num_layers=1,
|
668
|
+
cross_attention_dim=cross_attention_dim,
|
669
|
+
norm_num_groups=resnet_groups,
|
670
|
+
)
|
671
|
+
)
|
672
|
+
motion_modules.append(
|
673
|
+
AnimateDiffTransformer3D(
|
674
|
+
num_attention_heads=temporal_num_attention_heads,
|
675
|
+
in_channels=out_channels,
|
676
|
+
num_layers=temporal_transformer_layers_per_block[i],
|
677
|
+
norm_num_groups=resnet_groups,
|
678
|
+
cross_attention_dim=temporal_cross_attention_dim,
|
679
|
+
attention_bias=False,
|
680
|
+
activation_fn="geglu",
|
681
|
+
positional_embeddings="sinusoidal",
|
682
|
+
num_positional_embeddings=temporal_max_seq_length,
|
683
|
+
attention_head_dim=out_channels // temporal_num_attention_heads,
|
684
|
+
)
|
685
|
+
)
|
686
|
+
|
687
|
+
self.attentions = nn.ModuleList(attentions)
|
688
|
+
self.resnets = nn.ModuleList(resnets)
|
689
|
+
self.motion_modules = nn.ModuleList(motion_modules)
|
690
|
+
|
691
|
+
if add_upsample:
|
692
|
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
693
|
+
else:
|
694
|
+
self.upsamplers = None
|
695
|
+
|
696
|
+
self.gradient_checkpointing = False
|
697
|
+
self.resolution_idx = resolution_idx
|
698
|
+
|
699
|
+
def forward(
|
700
|
+
self,
|
701
|
+
hidden_states: torch.Tensor,
|
702
|
+
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
|
703
|
+
temb: Optional[torch.Tensor] = None,
|
704
|
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
705
|
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
706
|
+
upsample_size: Optional[int] = None,
|
707
|
+
attention_mask: Optional[torch.Tensor] = None,
|
708
|
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
709
|
+
num_frames: int = 1,
|
710
|
+
) -> torch.Tensor:
|
711
|
+
if cross_attention_kwargs is not None:
|
712
|
+
if cross_attention_kwargs.get("scale", None) is not None:
|
713
|
+
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
714
|
+
|
715
|
+
is_freeu_enabled = (
|
716
|
+
getattr(self, "s1", None)
|
717
|
+
and getattr(self, "s2", None)
|
718
|
+
and getattr(self, "b1", None)
|
719
|
+
and getattr(self, "b2", None)
|
720
|
+
)
|
721
|
+
|
722
|
+
blocks = zip(self.resnets, self.attentions, self.motion_modules)
|
723
|
+
for resnet, attn, motion_module in blocks:
|
724
|
+
# pop res hidden states
|
725
|
+
res_hidden_states = res_hidden_states_tuple[-1]
|
726
|
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
727
|
+
|
728
|
+
# FreeU: Only operate on the first two stages
|
729
|
+
if is_freeu_enabled:
|
730
|
+
hidden_states, res_hidden_states = apply_freeu(
|
731
|
+
self.resolution_idx,
|
732
|
+
hidden_states,
|
733
|
+
res_hidden_states,
|
734
|
+
s1=self.s1,
|
735
|
+
s2=self.s2,
|
736
|
+
b1=self.b1,
|
737
|
+
b2=self.b2,
|
738
|
+
)
|
739
|
+
|
740
|
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
741
|
+
|
742
|
+
if self.training and self.gradient_checkpointing:
|
743
|
+
|
744
|
+
def create_custom_forward(module, return_dict=None):
|
745
|
+
def custom_forward(*inputs):
|
746
|
+
if return_dict is not None:
|
747
|
+
return module(*inputs, return_dict=return_dict)
|
748
|
+
else:
|
749
|
+
return module(*inputs)
|
750
|
+
|
751
|
+
return custom_forward
|
752
|
+
|
753
|
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
754
|
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
755
|
+
create_custom_forward(resnet),
|
756
|
+
hidden_states,
|
757
|
+
temb,
|
758
|
+
**ckpt_kwargs,
|
759
|
+
)
|
760
|
+
hidden_states = attn(
|
761
|
+
hidden_states,
|
762
|
+
encoder_hidden_states=encoder_hidden_states,
|
763
|
+
cross_attention_kwargs=cross_attention_kwargs,
|
764
|
+
attention_mask=attention_mask,
|
765
|
+
encoder_attention_mask=encoder_attention_mask,
|
766
|
+
return_dict=False,
|
767
|
+
)[0]
|
768
|
+
else:
|
769
|
+
hidden_states = resnet(hidden_states, temb)
|
770
|
+
|
771
|
+
hidden_states = attn(
|
772
|
+
hidden_states,
|
773
|
+
encoder_hidden_states=encoder_hidden_states,
|
774
|
+
cross_attention_kwargs=cross_attention_kwargs,
|
775
|
+
attention_mask=attention_mask,
|
776
|
+
encoder_attention_mask=encoder_attention_mask,
|
777
|
+
return_dict=False,
|
778
|
+
)[0]
|
779
|
+
hidden_states = motion_module(
|
780
|
+
hidden_states,
|
781
|
+
num_frames=num_frames,
|
782
|
+
)
|
783
|
+
|
784
|
+
if self.upsamplers is not None:
|
785
|
+
for upsampler in self.upsamplers:
|
786
|
+
hidden_states = upsampler(hidden_states, upsample_size)
|
787
|
+
|
788
|
+
return hidden_states
|
789
|
+
|
790
|
+
|
791
|
+
class UpBlockMotion(nn.Module):
|
792
|
+
def __init__(
|
793
|
+
self,
|
794
|
+
in_channels: int,
|
795
|
+
prev_output_channel: int,
|
796
|
+
out_channels: int,
|
797
|
+
temb_channels: int,
|
798
|
+
resolution_idx: Optional[int] = None,
|
799
|
+
dropout: float = 0.0,
|
800
|
+
num_layers: int = 1,
|
801
|
+
resnet_eps: float = 1e-6,
|
802
|
+
resnet_time_scale_shift: str = "default",
|
803
|
+
resnet_act_fn: str = "swish",
|
804
|
+
resnet_groups: int = 32,
|
805
|
+
resnet_pre_norm: bool = True,
|
806
|
+
output_scale_factor: float = 1.0,
|
807
|
+
add_upsample: bool = True,
|
808
|
+
temporal_cross_attention_dim: Optional[int] = None,
|
809
|
+
temporal_num_attention_heads: int = 8,
|
810
|
+
temporal_max_seq_length: int = 32,
|
811
|
+
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
812
|
+
):
|
813
|
+
super().__init__()
|
814
|
+
resnets = []
|
815
|
+
motion_modules = []
|
816
|
+
|
817
|
+
# support for variable transformer layers per temporal block
|
818
|
+
if isinstance(temporal_transformer_layers_per_block, int):
|
819
|
+
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
|
820
|
+
elif len(temporal_transformer_layers_per_block) != num_layers:
|
821
|
+
raise ValueError(
|
822
|
+
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
|
823
|
+
)
|
824
|
+
|
825
|
+
for i in range(num_layers):
|
826
|
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
827
|
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
828
|
+
|
829
|
+
resnets.append(
|
830
|
+
ResnetBlock2D(
|
831
|
+
in_channels=resnet_in_channels + res_skip_channels,
|
832
|
+
out_channels=out_channels,
|
833
|
+
temb_channels=temb_channels,
|
834
|
+
eps=resnet_eps,
|
835
|
+
groups=resnet_groups,
|
836
|
+
dropout=dropout,
|
837
|
+
time_embedding_norm=resnet_time_scale_shift,
|
838
|
+
non_linearity=resnet_act_fn,
|
839
|
+
output_scale_factor=output_scale_factor,
|
840
|
+
pre_norm=resnet_pre_norm,
|
841
|
+
)
|
842
|
+
)
|
843
|
+
|
844
|
+
motion_modules.append(
|
845
|
+
AnimateDiffTransformer3D(
|
846
|
+
num_attention_heads=temporal_num_attention_heads,
|
847
|
+
in_channels=out_channels,
|
848
|
+
num_layers=temporal_transformer_layers_per_block[i],
|
849
|
+
norm_num_groups=resnet_groups,
|
850
|
+
cross_attention_dim=temporal_cross_attention_dim,
|
851
|
+
attention_bias=False,
|
852
|
+
activation_fn="geglu",
|
853
|
+
positional_embeddings="sinusoidal",
|
854
|
+
num_positional_embeddings=temporal_max_seq_length,
|
855
|
+
attention_head_dim=out_channels // temporal_num_attention_heads,
|
856
|
+
)
|
857
|
+
)
|
858
|
+
|
859
|
+
self.resnets = nn.ModuleList(resnets)
|
860
|
+
self.motion_modules = nn.ModuleList(motion_modules)
|
861
|
+
|
862
|
+
if add_upsample:
|
863
|
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
864
|
+
else:
|
865
|
+
self.upsamplers = None
|
866
|
+
|
867
|
+
self.gradient_checkpointing = False
|
868
|
+
self.resolution_idx = resolution_idx
|
869
|
+
|
870
|
+
def forward(
|
871
|
+
self,
|
872
|
+
hidden_states: torch.Tensor,
|
873
|
+
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
|
874
|
+
temb: Optional[torch.Tensor] = None,
|
875
|
+
upsample_size=None,
|
876
|
+
num_frames: int = 1,
|
877
|
+
*args,
|
878
|
+
**kwargs,
|
879
|
+
) -> torch.Tensor:
|
880
|
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
881
|
+
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
882
|
+
deprecate("scale", "1.0.0", deprecation_message)
|
883
|
+
|
884
|
+
is_freeu_enabled = (
|
885
|
+
getattr(self, "s1", None)
|
886
|
+
and getattr(self, "s2", None)
|
887
|
+
and getattr(self, "b1", None)
|
888
|
+
and getattr(self, "b2", None)
|
889
|
+
)
|
890
|
+
|
891
|
+
blocks = zip(self.resnets, self.motion_modules)
|
892
|
+
|
893
|
+
for resnet, motion_module in blocks:
|
894
|
+
# pop res hidden states
|
895
|
+
res_hidden_states = res_hidden_states_tuple[-1]
|
896
|
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
897
|
+
|
898
|
+
# FreeU: Only operate on the first two stages
|
899
|
+
if is_freeu_enabled:
|
900
|
+
hidden_states, res_hidden_states = apply_freeu(
|
901
|
+
self.resolution_idx,
|
902
|
+
hidden_states,
|
903
|
+
res_hidden_states,
|
904
|
+
s1=self.s1,
|
905
|
+
s2=self.s2,
|
906
|
+
b1=self.b1,
|
907
|
+
b2=self.b2,
|
908
|
+
)
|
909
|
+
|
910
|
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
911
|
+
|
912
|
+
if self.training and self.gradient_checkpointing:
|
913
|
+
|
914
|
+
def create_custom_forward(module):
|
915
|
+
def custom_forward(*inputs):
|
916
|
+
return module(*inputs)
|
917
|
+
|
918
|
+
return custom_forward
|
919
|
+
|
920
|
+
if is_torch_version(">=", "1.11.0"):
|
921
|
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
922
|
+
create_custom_forward(resnet),
|
923
|
+
hidden_states,
|
924
|
+
temb,
|
925
|
+
use_reentrant=False,
|
926
|
+
)
|
927
|
+
else:
|
928
|
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
929
|
+
create_custom_forward(resnet), hidden_states, temb
|
930
|
+
)
|
931
|
+
else:
|
932
|
+
hidden_states = resnet(hidden_states, temb)
|
933
|
+
|
934
|
+
hidden_states = motion_module(hidden_states, num_frames=num_frames)
|
935
|
+
|
936
|
+
if self.upsamplers is not None:
|
937
|
+
for upsampler in self.upsamplers:
|
938
|
+
hidden_states = upsampler(hidden_states, upsample_size)
|
939
|
+
|
940
|
+
return hidden_states
|
941
|
+
|
942
|
+
|
943
|
+
class UNetMidBlockCrossAttnMotion(nn.Module):
|
944
|
+
def __init__(
|
945
|
+
self,
|
946
|
+
in_channels: int,
|
947
|
+
temb_channels: int,
|
948
|
+
dropout: float = 0.0,
|
949
|
+
num_layers: int = 1,
|
950
|
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
951
|
+
resnet_eps: float = 1e-6,
|
952
|
+
resnet_time_scale_shift: str = "default",
|
953
|
+
resnet_act_fn: str = "swish",
|
954
|
+
resnet_groups: int = 32,
|
955
|
+
resnet_pre_norm: bool = True,
|
956
|
+
num_attention_heads: int = 1,
|
957
|
+
output_scale_factor: float = 1.0,
|
958
|
+
cross_attention_dim: int = 1280,
|
959
|
+
dual_cross_attention: bool = False,
|
960
|
+
use_linear_projection: bool = False,
|
961
|
+
upcast_attention: bool = False,
|
962
|
+
attention_type: str = "default",
|
963
|
+
temporal_num_attention_heads: int = 1,
|
964
|
+
temporal_cross_attention_dim: Optional[int] = None,
|
965
|
+
temporal_max_seq_length: int = 32,
|
966
|
+
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    ):
+        super().__init__()
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
+        elif len(transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
+            )
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
+            )
+
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlock2D(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=resnet_groups,
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+            )
+        ]
+        attentions = []
+        motion_modules = []
+
+        for i in range(num_layers):
+            if not dual_cross_attention:
+                attentions.append(
+                    Transformer2DModel(
+                        num_attention_heads,
+                        in_channels // num_attention_heads,
+                        in_channels=in_channels,
+                        num_layers=transformer_layers_per_block[i],
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
+                        upcast_attention=upcast_attention,
+                        attention_type=attention_type,
+                    )
+                )
+            else:
+                attentions.append(
+                    DualTransformer2DModel(
+                        num_attention_heads,
+                        in_channels // num_attention_heads,
+                        in_channels=in_channels,
+                        num_layers=1,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                    )
+                )
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            motion_modules.append(
+                AnimateDiffTransformer3D(
+                    num_attention_heads=temporal_num_attention_heads,
+                    attention_head_dim=in_channels // temporal_num_attention_heads,
+                    in_channels=in_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
+                    norm_num_groups=resnet_groups,
+                    cross_attention_dim=temporal_cross_attention_dim,
+                    attention_bias=False,
+                    positional_embeddings="sinusoidal",
+                    num_positional_embeddings=temporal_max_seq_length,
+                    activation_fn="geglu",
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+        self.motion_modules = nn.ModuleList(motion_modules)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        num_frames: int = 1,
+    ) -> torch.Tensor:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+        hidden_states = self.resnets[0](hidden_states, temb)
+
+        blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
+        for attn, resnet, motion_module in blocks:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(motion_module),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = motion_module(
+                    hidden_states,
+                    num_frames=num_frames,
+                )
+                hidden_states = resnet(hidden_states, temb)
+
+        return hidden_states
+
+
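The added constructor above normalizes `transformer_layers_per_block` and `temporal_transformer_layers_per_block` with the same int-or-tuple pattern that recurs throughout this diff: an `int` is broadcast to one entry per layer, while a tuple must already have exactly `num_layers` entries. A minimal standalone sketch of that pattern (the helper name is hypothetical, not part of diffusers):

```python
from typing import Tuple, Union


def broadcast_per_layer(value: Union[int, Tuple[int, ...]], num_layers: int) -> Tuple[int, ...]:
    # hypothetical helper mirroring the int-or-tuple normalization used in the constructors above
    if isinstance(value, int):
        return (value,) * num_layers
    if len(value) != num_layers:
        raise ValueError(f"expected an int or a sequence of length {num_layers}, got {len(value)}")
    return tuple(value)


assert broadcast_per_layer(2, 3) == (2, 2, 2)
assert broadcast_per_layer((1, 2, 4), 3) == (1, 2, 4)
```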
 class MotionModules(nn.Module):
     def __init__(
         self,
         in_channels: int,
         layers_per_block: int = 2,
-
+        transformer_layers_per_block: Union[int, Tuple[int]] = 8,
+        num_attention_heads: Union[int, Tuple[int]] = 8,
         attention_bias: bool = False,
         cross_attention_dim: Optional[int] = None,
         activation_fn: str = "geglu",
@@ -67,10 +1150,19 @@ class MotionModules(nn.Module):
         super().__init__()
         self.motion_modules = nn.ModuleList([])

+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * layers_per_block
+        elif len(transformer_layers_per_block) != layers_per_block:
+            raise ValueError(
+                f"The number of transformer layers per block must match the number of layers per block, "
+                f"got {layers_per_block} and {len(transformer_layers_per_block)}"
+            )
+
         for i in range(layers_per_block):
             self.motion_modules.append(
-
+                AnimateDiffTransformer3D(
                     in_channels=in_channels,
+                    num_layers=transformer_layers_per_block[i],
                     norm_num_groups=norm_num_groups,
                     cross_attention_dim=cross_attention_dim,
                     activation_fn=activation_fn,
@@ -83,14 +1175,16 @@ class MotionModules(nn.Module):
             )


-class MotionAdapter(ModelMixin, ConfigMixin):
+class MotionAdapter(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     @register_to_config
     def __init__(
         self,
         block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
-        motion_layers_per_block: int = 2,
+        motion_layers_per_block: Union[int, Tuple[int]] = 2,
+        motion_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1,
         motion_mid_block_layers_per_block: int = 1,
-
+        motion_transformer_layers_per_mid_block: Union[int, Tuple[int]] = 1,
+        motion_num_attention_heads: Union[int, Tuple[int]] = 8,
         motion_norm_num_groups: int = 32,
         motion_max_seq_length: int = 32,
         use_motion_mid_block: bool = True,
@@ -101,11 +1195,15 @@ class MotionAdapter(ModelMixin, ConfigMixin):
         Args:
             block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
                 The tuple of output channels for each UNet block.
-            motion_layers_per_block (`int`, *optional*, defaults to 2):
+            motion_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 2):
                 The number of motion layers per UNet block.
+            motion_transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple[int]]`, *optional*, defaults to 1):
+                The number of transformer layers to use in each motion layer in each block.
             motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1):
                 The number of motion layers in the middle UNet block.
-
+            motion_transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+                The number of transformer layers to use in each motion layer in the middle block.
+            motion_num_attention_heads (`int` or `Tuple[int]`, *optional*, defaults to 8):
                 The number of heads to use in each attention layer of the motion module.
             motion_norm_num_groups (`int`, *optional*, defaults to 32):
                 The number of groups to use in each group normalization layer of the motion module.
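The per-block options documented above accept either a single `int` or one value per UNet block. A minimal sketch of constructing a `MotionAdapter` with per-block settings (the concrete values are illustrative, not recommended defaults):

```python
from diffusers import MotionAdapter

# illustrative values; each tuple must provide one entry per block_out_channels entry
adapter = MotionAdapter(
    block_out_channels=(320, 640, 1280, 1280),
    motion_layers_per_block=(2, 2, 2, 2),
    motion_transformer_layers_per_block=(1, 1, 2, 2),
    motion_num_attention_heads=(8, 8, 8, 8),
)
```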
@@ -119,6 +1217,35 @@ class MotionAdapter(ModelMixin, ConfigMixin):
         down_blocks = []
         up_blocks = []

+        if isinstance(motion_layers_per_block, int):
+            motion_layers_per_block = (motion_layers_per_block,) * len(block_out_channels)
+        elif len(motion_layers_per_block) != len(block_out_channels):
+            raise ValueError(
+                f"The number of motion layers per block must match the number of blocks, "
+                f"got {len(block_out_channels)} and {len(motion_layers_per_block)}"
+            )
+
+        if isinstance(motion_transformer_layers_per_block, int):
+            motion_transformer_layers_per_block = (motion_transformer_layers_per_block,) * len(block_out_channels)
+
+        if isinstance(motion_transformer_layers_per_mid_block, int):
+            motion_transformer_layers_per_mid_block = (
+                motion_transformer_layers_per_mid_block,
+            ) * motion_mid_block_layers_per_block
+        elif len(motion_transformer_layers_per_mid_block) != motion_mid_block_layers_per_block:
+            raise ValueError(
+                f"The number of layers per mid block ({motion_mid_block_layers_per_block}) "
+                f"must match the length of motion_transformer_layers_per_mid_block ({len(motion_transformer_layers_per_mid_block)})"
+            )
+
+        if isinstance(motion_num_attention_heads, int):
+            motion_num_attention_heads = (motion_num_attention_heads,) * len(block_out_channels)
+        elif len(motion_num_attention_heads) != len(block_out_channels):
+            raise ValueError(
+                f"The length of the attention head number tuple in the motion module must match the "
+                f"number of block, got {len(motion_num_attention_heads)} and {len(block_out_channels)}"
+            )
+
         if conv_in_channels:
             # input
             self.conv_in = nn.Conv2d(conv_in_channels, block_out_channels[0], kernel_size=3, padding=1)
@@ -134,9 +1261,10 @@ class MotionAdapter(ModelMixin, ConfigMixin):
                     cross_attention_dim=None,
                     activation_fn="geglu",
                     attention_bias=False,
-                    num_attention_heads=motion_num_attention_heads,
+                    num_attention_heads=motion_num_attention_heads[i],
                     max_seq_length=motion_max_seq_length,
-                    layers_per_block=motion_layers_per_block,
+                    layers_per_block=motion_layers_per_block[i],
+                    transformer_layers_per_block=motion_transformer_layers_per_block[i],
                 )
             )

@@ -147,15 +1275,20 @@ class MotionAdapter(ModelMixin, ConfigMixin):
                 cross_attention_dim=None,
                 activation_fn="geglu",
                 attention_bias=False,
-                num_attention_heads=motion_num_attention_heads,
-                layers_per_block=motion_mid_block_layers_per_block,
+                num_attention_heads=motion_num_attention_heads[-1],
                 max_seq_length=motion_max_seq_length,
+                layers_per_block=motion_mid_block_layers_per_block,
+                transformer_layers_per_block=motion_transformer_layers_per_mid_block,
             )
         else:
             self.mid_block = None

         reversed_block_out_channels = list(reversed(block_out_channels))
         output_channel = reversed_block_out_channels[0]
+
+        reversed_motion_layers_per_block = list(reversed(motion_layers_per_block))
+        reversed_motion_transformer_layers_per_block = list(reversed(motion_transformer_layers_per_block))
+        reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads))
         for i, channel in enumerate(reversed_block_out_channels):
             output_channel = reversed_block_out_channels[i]
             up_blocks.append(
@@ -165,9 +1298,10 @@ class MotionAdapter(ModelMixin, ConfigMixin):
                     cross_attention_dim=None,
                     activation_fn="geglu",
                     attention_bias=False,
-                    num_attention_heads=
+                    num_attention_heads=reversed_motion_num_attention_heads[i],
                     max_seq_length=motion_max_seq_length,
-                    layers_per_block=
+                    layers_per_block=reversed_motion_layers_per_block[i] + 1,
+                    transformer_layers_per_block=reversed_motion_transformer_layers_per_block[i],
                 )
             )

@@ -178,7 +1312,7 @@ class MotionAdapter(ModelMixin, ConfigMixin):
         pass


-class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
     r"""
     A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a
     sample shaped output.
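With `PeftAdapterMixin` added to the base classes above, PEFT adapters can be attached directly to the motion UNet. A hedged sketch, assuming `peft` is installed; the tiny model configuration and the LoRA hyperparameters are illustrative only:

```python
from diffusers import UNetMotionModel
from peft import LoraConfig

# deliberately tiny, illustrative config so the sketch builds quickly
motion_unet = UNetMotionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlockMotion", "DownBlockMotion"),
    up_block_types=("UpBlockMotion", "CrossAttnUpBlockMotion"),
    cross_attention_dim=32,
    num_attention_heads=4,
    layers_per_block=1,
)

# attach and activate a LoRA adapter on the attention projections (names/values are illustrative)
lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
motion_unet.add_adapter(lora_config, adapter_name="motion_lora")
motion_unet.set_adapter("motion_lora")

# adapters can be toggled off and back on
motion_unet.disable_adapters()
motion_unet.enable_adapters()
```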
@@ -208,7 +1342,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             "CrossAttnUpBlockMotion",
         ),
         block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
-        layers_per_block: int = 2,
+        layers_per_block: Union[int, Tuple[int]] = 2,
         downsample_padding: int = 1,
         mid_block_scale_factor: float = 1,
         act_fn: str = "silu",
@@ -216,12 +1350,18 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
         transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
-        reverse_transformer_layers_per_block: Optional[
+        reverse_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+        reverse_temporal_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
+        transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
+        temporal_transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = 1,
         use_linear_projection: bool = False,
         num_attention_heads: Union[int, Tuple[int, ...]] = 8,
         motion_max_seq_length: int = 32,
-        motion_num_attention_heads: int = 8,
-
+        motion_num_attention_heads: Union[int, Tuple[int, ...]] = 8,
+        reverse_motion_num_attention_heads: Optional[Union[int, Tuple[int, ...], Tuple[Tuple[int, ...], ...]]] = None,
+        use_motion_mid_block: bool = True,
+        mid_block_layers: int = 1,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
         addition_embed_type: Optional[str] = None,
@@ -264,6 +1404,16 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 if isinstance(layer_number_per_block, list):
                     raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

+        if (
+            isinstance(temporal_transformer_layers_per_block, list)
+            and reverse_temporal_transformer_layers_per_block is None
+        ):
+            for layer_number_per_block in temporal_transformer_layers_per_block:
+                if isinstance(layer_number_per_block, list):
+                    raise ValueError(
+                        "Must provide 'reverse_temporal_transformer_layers_per_block` if using asymmetrical motion module in UNet."
+                    )
+
         # input
         conv_in_kernel = 3
         conv_out_kernel = 3
@@ -304,6 +1454,20 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if isinstance(transformer_layers_per_block, int):
             transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

+        if isinstance(reverse_transformer_layers_per_block, int):
+            reverse_transformer_layers_per_block = [reverse_transformer_layers_per_block] * len(down_block_types)
+
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types)
+
+        if isinstance(reverse_temporal_transformer_layers_per_block, int):
+            reverse_temporal_transformer_layers_per_block = [reverse_temporal_transformer_layers_per_block] * len(
+                down_block_types
+            )
+
+        if isinstance(motion_num_attention_heads, int):
+            motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types)
+
         # down
         output_channel = block_out_channels[0]
         for i, down_block_type in enumerate(down_block_types):
@@ -311,28 +1475,53 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             output_channel = block_out_channels[i]
             is_final_block = i == len(block_out_channels) - 1

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if down_block_type == "CrossAttnDownBlockMotion":
+                down_block = CrossAttnDownBlockMotion(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    temb_channels=time_embed_dim,
+                    num_layers=layers_per_block[i],
+                    transformer_layers_per_block=transformer_layers_per_block[i],
+                    resnet_eps=norm_eps,
+                    resnet_act_fn=act_fn,
+                    resnet_groups=norm_num_groups,
+                    num_attention_heads=num_attention_heads[i],
+                    cross_attention_dim=cross_attention_dim[i],
+                    downsample_padding=downsample_padding,
+                    add_downsample=not is_final_block,
+                    use_linear_projection=use_linear_projection,
+                    temporal_num_attention_heads=motion_num_attention_heads[i],
+                    temporal_max_seq_length=motion_max_seq_length,
+                    temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
+                )
+            elif down_block_type == "DownBlockMotion":
+                down_block = DownBlockMotion(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    temb_channels=time_embed_dim,
+                    num_layers=layers_per_block[i],
+                    resnet_eps=norm_eps,
+                    resnet_act_fn=act_fn,
+                    resnet_groups=norm_num_groups,
+                    add_downsample=not is_final_block,
+                    downsample_padding=downsample_padding,
+                    temporal_num_attention_heads=motion_num_attention_heads[i],
+                    temporal_max_seq_length=motion_max_seq_length,
+                    temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
+                )
+            else:
+                raise ValueError(
+                    "Invalid `down_block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`"
+                )
+
             self.down_blocks.append(down_block)

         # mid
+        if transformer_layers_per_mid_block is None:
+            transformer_layers_per_mid_block = (
+                transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1
+            )
+
         if use_motion_mid_block:
             self.mid_block = UNetMidBlockCrossAttnMotion(
                 in_channels=block_out_channels[-1],
@@ -345,9 +1534,11 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=False,
                 use_linear_projection=use_linear_projection,
-
+                num_layers=mid_block_layers,
+                temporal_num_attention_heads=motion_num_attention_heads[-1],
                 temporal_max_seq_length=motion_max_seq_length,
-                transformer_layers_per_block=
+                transformer_layers_per_block=transformer_layers_per_mid_block,
+                temporal_transformer_layers_per_block=temporal_transformer_layers_per_mid_block,
             )

         else:
@@ -362,7 +1553,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=False,
                 use_linear_projection=use_linear_projection,
-
+                num_layers=mid_block_layers,
+                transformer_layers_per_block=transformer_layers_per_mid_block,
             )

         # count how many layers upsample the images
@@ -373,7 +1565,13 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         reversed_num_attention_heads = list(reversed(num_attention_heads))
         reversed_layers_per_block = list(reversed(layers_per_block))
         reversed_cross_attention_dim = list(reversed(cross_attention_dim))
-
+        reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads))
+
+        if reverse_transformer_layers_per_block is None:
+            reverse_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+
+        if reverse_temporal_transformer_layers_per_block is None:
+            reverse_temporal_transformer_layers_per_block = list(reversed(temporal_transformer_layers_per_block))

         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
@@ -390,26 +1588,47 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             else:
                 add_upsample = False

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if up_block_type == "CrossAttnUpBlockMotion":
+                up_block = CrossAttnUpBlockMotion(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    prev_output_channel=prev_output_channel,
+                    temb_channels=time_embed_dim,
+                    resolution_idx=i,
+                    num_layers=reversed_layers_per_block[i] + 1,
+                    transformer_layers_per_block=reverse_transformer_layers_per_block[i],
+                    resnet_eps=norm_eps,
+                    resnet_act_fn=act_fn,
+                    resnet_groups=norm_num_groups,
+                    num_attention_heads=reversed_num_attention_heads[i],
+                    cross_attention_dim=reversed_cross_attention_dim[i],
+                    add_upsample=add_upsample,
+                    use_linear_projection=use_linear_projection,
+                    temporal_num_attention_heads=reversed_motion_num_attention_heads[i],
+                    temporal_max_seq_length=motion_max_seq_length,
+                    temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i],
+                )
+            elif up_block_type == "UpBlockMotion":
+                up_block = UpBlockMotion(
+                    in_channels=input_channel,
+                    prev_output_channel=prev_output_channel,
+                    out_channels=output_channel,
+                    temb_channels=time_embed_dim,
+                    resolution_idx=i,
+                    num_layers=reversed_layers_per_block[i] + 1,
+                    resnet_eps=norm_eps,
+                    resnet_act_fn=act_fn,
+                    resnet_groups=norm_num_groups,
+                    add_upsample=add_upsample,
+                    temporal_num_attention_heads=reversed_motion_num_attention_heads[i],
+                    temporal_max_seq_length=motion_max_seq_length,
+                    temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i],
+                )
+            else:
+                raise ValueError(
+                    "Invalid `up_block_type` encountered. Must be one of `CrossAttnUpBlockMotion` or `UpBlockMotion`"
+                )
+
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel

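The constructor changes above thread the new temporal settings into every down, mid, and up block, and reverse them for the up path when no explicit `reverse_*` values are given. A hedged sketch of the new arguments on the same tiny configuration as the earlier sketch (values are illustrative):

```python
from diffusers import UNetMotionModel

# per-block temporal transformer depths and motion attention heads; one entry per block
tiny_unet = UNetMotionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlockMotion", "DownBlockMotion"),
    up_block_types=("UpBlockMotion", "CrossAttnUpBlockMotion"),
    cross_attention_dim=32,
    num_attention_heads=4,
    layers_per_block=1,
    temporal_transformer_layers_per_block=(1, 2),
    motion_num_attention_heads=(4, 8),
)
```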
@@ -440,6 +1659,24 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if has_motion_adapter:
             motion_adapter.to(device=unet.device)

+            # check compatibility of number of blocks
+            if len(unet.config["down_block_types"]) != len(motion_adapter.config["block_out_channels"]):
+                raise ValueError("Incompatible Motion Adapter, got different number of blocks")
+
+            # check layers compatibility for each block
+            if isinstance(unet.config["layers_per_block"], int):
+                expanded_layers_per_block = [unet.config["layers_per_block"]] * len(unet.config["down_block_types"])
+            else:
+                expanded_layers_per_block = list(unet.config["layers_per_block"])
+            if isinstance(motion_adapter.config["motion_layers_per_block"], int):
+                expanded_adapter_layers_per_block = [motion_adapter.config["motion_layers_per_block"]] * len(
+                    motion_adapter.config["block_out_channels"]
+                )
+            else:
+                expanded_adapter_layers_per_block = list(motion_adapter.config["motion_layers_per_block"])
+            if expanded_layers_per_block != expanded_adapter_layers_per_block:
+                raise ValueError("Incompatible Motion Adapter, got different number of layers per block")
+
         # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
         config = dict(unet.config)
         config["_class_name"] = cls.__name__
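The checks added above reject a motion adapter whose block count or per-block layer count does not match the 2D UNet being adapted. A hedged usage sketch of the conversion path where those checks run (repo ids are illustrative, and the calls download pretrained weights):

```python
from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel

# any SD1.5-style UNet and AnimateDiff adapter pair with matching block counts
# and layers per block passes the compatibility checks above
unet2d = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
motion_unet = UNetMotionModel.from_unet2d(unet2d, adapter)
```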
@@ -458,13 +1695,20 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 up_blocks.append("CrossAttnUpBlockMotion")
             else:
                 up_blocks.append("UpBlockMotion")
-
         config["up_block_types"] = up_blocks

         if has_motion_adapter:
             config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]
             config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"]
             config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"]
+            config["layers_per_block"] = motion_adapter.config["motion_layers_per_block"]
+            config["temporal_transformer_layers_per_mid_block"] = motion_adapter.config[
+                "motion_transformer_layers_per_mid_block"
+            ]
+            config["temporal_transformer_layers_per_block"] = motion_adapter.config[
+                "motion_transformer_layers_per_block"
+            ]
+            config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]

         # For PIA UNets we need to set the number input channels to 9
         if motion_adapter.config["conv_in_channels"]:
@@ -474,7 +1718,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if not config.get("num_attention_heads"):
             config["num_attention_heads"] = config["attention_head_dim"]

-
+        expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
+        config = FrozenDict({k: config.get(k) for k in config if k in expected_kwargs or k in optional_kwargs})
+        config["_class_name"] = cls.__name__
         model = cls.from_config(config)

         if not load_weights:
@@ -637,7 +1883,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):

         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(
+                processors[f"{name}.processor"] = module.get_processor()

             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -684,7 +1930,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)

-    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
     def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
         """
         Sets the attention processor to use [feed forward
@@ -714,7 +1959,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         for module in self.children():
             fn_recursive_feed_forward(module, chunk_size, dim)

-    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
     def disable_forward_chunking(self) -> None:
         def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
             if hasattr(module, "set_chunk_feed_forward"):
@@ -804,6 +2048,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)

+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
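`fuse_qkv_projections` now also installs `FusedAttnProcessor2_0` after fusing the projections. A hedged sketch of the intended round trip, assuming PyTorch 2.x and reusing the `motion_unet` built in the earlier sketches:

```python
# fuse the q/k/v projections; the attention processors are swapped for the fused variant
motion_unet.fuse_qkv_projections()
print(type(next(iter(motion_unet.attn_processors.values()))).__name__)  # FusedAttnProcessor2_0

# ... run inference ...

# restore the unfused projections and the default processors
motion_unet.unfuse_qkv_projections()
```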
@@ -830,7 +2076,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         mid_block_additional_residual: Optional[torch.Tensor] = None,
         return_dict: bool = True,
-    ) -> Union[
+    ) -> Union[UNetMotionOutput, Tuple[torch.Tensor]]:
         r"""
         The [`UNetMotionModel`] forward method.

@@ -856,12 +2102,12 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             mid_block_additional_residual: (`torch.Tensor`, *optional*):
                 A tensor that if specified is added to the residual of the middle unet block.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.unets.
+                Whether or not to return a [`~models.unets.unet_motion_model.UNetMotionOutput`] instead of a plain
                 tuple.

         Returns:
-            [`~models.unets.
-            If `return_dict` is True, an [`~models.unets.
+            [`~models.unets.unet_motion_model.UNetMotionOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unets.unet_motion_model.UNetMotionOutput`] is returned,
                 otherwise a `tuple` is returned where the first element is the sample tensor.
         """
         # By default samples have to be AT least a multiple of the overall upsampling factor.
@@ -1045,4 +2291,4 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if not return_dict:
             return (sample,)

-        return
+        return UNetMotionOutput(sample=sample)
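With the return annotation and the final return statement fixed up above, `forward` yields a `UNetMotionOutput` by default and a plain tuple when `return_dict=False`. A hedged sketch reusing the tiny `motion_unet` from the earlier sketches; shapes are illustrative and match that tiny configuration:

```python
import torch

sample = torch.randn(1, 4, 2, 8, 8)             # (batch, channels, num_frames, height, width)
encoder_hidden_states = torch.randn(1, 16, 32)  # last dim matches the tiny cross_attention_dim

out = motion_unet(sample, 1, encoder_hidden_states)                                # UNetMotionOutput
sample_only = motion_unet(sample, 1, encoder_hidden_states, return_dict=False)[0]  # plain tensor
print(out.sample.shape, sample_only.shape)
```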