diffusers-0.23.1-py3-none-any.whl → diffusers-0.25.0-py3-none-any.whl
- diffusers/__init__.py +26 -2
- diffusers/commands/fp16_safetensors.py +10 -11
- diffusers/configuration_utils.py +13 -8
- diffusers/dependency_versions_check.py +0 -1
- diffusers/dependency_versions_table.py +5 -5
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +463 -51
- diffusers/loaders/__init__.py +82 -0
- diffusers/loaders/ip_adapter.py +159 -0
- diffusers/loaders/lora.py +1553 -0
- diffusers/loaders/lora_conversion_utils.py +284 -0
- diffusers/loaders/single_file.py +637 -0
- diffusers/loaders/textual_inversion.py +455 -0
- diffusers/loaders/unet.py +828 -0
- diffusers/loaders/utils.py +59 -0
- diffusers/models/__init__.py +26 -9
- diffusers/models/activations.py +9 -6
- diffusers/models/attention.py +301 -29
- diffusers/models/attention_flax.py +9 -1
- diffusers/models/attention_processor.py +378 -6
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
- diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
- diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
- diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
- diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
- diffusers/models/controlnet.py +59 -39
- diffusers/models/controlnet_flax.py +19 -18
- diffusers/models/downsampling.py +338 -0
- diffusers/models/embeddings.py +112 -29
- diffusers/models/embeddings_flax.py +2 -0
- diffusers/models/lora.py +131 -1
- diffusers/models/modeling_flax_utils.py +14 -8
- diffusers/models/modeling_outputs.py +17 -0
- diffusers/models/modeling_utils.py +37 -29
- diffusers/models/normalization.py +110 -4
- diffusers/models/resnet.py +299 -652
- diffusers/models/transformer_2d.py +22 -5
- diffusers/models/transformer_temporal.py +183 -1
- diffusers/models/unet_2d_blocks_flax.py +5 -0
- diffusers/models/unet_2d_condition.py +46 -0
- diffusers/models/unet_2d_condition_flax.py +13 -13
- diffusers/models/unet_3d_blocks.py +957 -173
- diffusers/models/unet_3d_condition.py +16 -8
- diffusers/models/unet_kandinsky3.py +535 -0
- diffusers/models/unet_motion_model.py +48 -33
- diffusers/models/unet_spatio_temporal_condition.py +489 -0
- diffusers/models/upsampling.py +454 -0
- diffusers/models/uvit_2d.py +471 -0
- diffusers/models/vae_flax.py +7 -0
- diffusers/models/vq_model.py +12 -3
- diffusers/optimization.py +16 -9
- diffusers/pipelines/__init__.py +137 -76
- diffusers/pipelines/amused/__init__.py +62 -0
- diffusers/pipelines/amused/pipeline_amused.py +328 -0
- diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
- diffusers/pipelines/auto_pipeline.py +23 -13
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
- diffusers/pipelines/deprecated/__init__.py +153 -0
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
- diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
- diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
- diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
- diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
- diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +1 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
- diffusers/pipelines/kandinsky3/__init__.py +49 -0
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
- diffusers/pipelines/onnx_utils.py +8 -5
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
- diffusers/pipelines/pipeline_flax_utils.py +11 -8
- diffusers/pipelines/pipeline_utils.py +63 -42
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/__init__.py +37 -65
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
- diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
- diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
- diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
- diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
- diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
- diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
- diffusers/schedulers/__init__.py +4 -4
- diffusers/schedulers/deprecated/__init__.py +50 -0
- diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
- diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
- diffusers/schedulers/scheduling_amused.py +162 -0
- diffusers/schedulers/scheduling_consistency_models.py +2 -0
- diffusers/schedulers/scheduling_ddim.py +1 -3
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
- diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
- diffusers/schedulers/scheduling_ddpm.py +47 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
- diffusers/schedulers/scheduling_deis_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
- diffusers/schedulers/scheduling_euler_discrete.py +102 -16
- diffusers/schedulers/scheduling_heun_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
- diffusers/schedulers/scheduling_lcm.py +123 -29
- diffusers/schedulers/scheduling_lms_discrete.py +3 -3
- diffusers/schedulers/scheduling_pndm.py +1 -3
- diffusers/schedulers/scheduling_repaint.py +1 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
- diffusers/schedulers/scheduling_utils.py +3 -1
- diffusers/schedulers/scheduling_utils_flax.py +3 -1
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +1 -2
- diffusers/utils/constants.py +10 -12
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +18 -22
- diffusers/utils/export_utils.py +8 -3
- diffusers/utils/hub_utils.py +24 -36
- diffusers/utils/logging.py +11 -11
- diffusers/utils/outputs.py +5 -5
- diffusers/utils/peft_utils.py +88 -44
- diffusers/utils/state_dict_utils.py +8 -0
- diffusers/utils/testing_utils.py +199 -1
- diffusers/utils/torch_utils.py +4 -4
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
- diffusers-0.25.0.dist-info/RECORD +360 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
- diffusers/loaders.py +0 -3336
- diffusers-0.23.1.dist-info/RECORD +0 -323
- /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
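
Note: the headline changes are visible in the file list above — the monolithic diffusers/loaders.py (-3336 lines) is split into a diffusers/loaders/ package, several older pipelines move under diffusers/pipelines/deprecated/, and new files land for Stable Video Diffusion, aMUSEd, and Kandinsky 3. As a hedged sketch of what the biggest addition enables after upgrading (the model id and call signature below are assumptions based on the 0.25.0 API, not something this file list spells out):

# Assumes: pip install --upgrade "diffusers==0.25.0" transformers accelerate
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.to("cuda")

image = load_image("https://example.com/conditioning_frame.png")  # placeholder URL
frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]  # list of PIL frames

The section below is the diff of diffusers/models/unet_3d_blocks.py (+957 -173 per the list above).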
--- a/diffusers/models/unet_3d_blocks.py
+++ b/diffusers/models/unet_3d_blocks.py
@@ -12,40 +12,58 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 from torch import nn

 from ..utils import is_torch_version
 from ..utils.torch_utils import apply_freeu
+from .attention import Attention
 from .dual_transformer_2d import DualTransformer2DModel
-from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
+from .resnet import (
+    Downsample2D,
+    ResnetBlock2D,
+    SpatioTemporalResBlock,
+    TemporalConvLayer,
+    Upsample2D,
+)
 from .transformer_2d import Transformer2DModel
-from .transformer_temporal import TransformerTemporalModel
+from .transformer_temporal import (
+    TransformerSpatioTemporalModel,
+    TransformerTemporalModel,
+)


 def get_down_block(
-    down_block_type,
-    num_layers,
-    in_channels,
-    out_channels,
-    temb_channels,
-    add_downsample,
-    resnet_eps,
-    resnet_act_fn,
-    num_attention_heads,
-    resnet_groups=None,
-    cross_attention_dim=None,
-    downsample_padding=None,
-    dual_cross_attention=False,
-    use_linear_projection=True,
-    only_cross_attention=False,
-    upcast_attention=False,
-    resnet_time_scale_shift="default",
-    temporal_num_attention_heads=8,
-    temporal_max_seq_length=32,
-):
+    down_block_type: str,
+    num_layers: int,
+    in_channels: int,
+    out_channels: int,
+    temb_channels: int,
+    add_downsample: bool,
+    resnet_eps: float,
+    resnet_act_fn: str,
+    num_attention_heads: int,
+    resnet_groups: Optional[int] = None,
+    cross_attention_dim: Optional[int] = None,
+    downsample_padding: Optional[int] = None,
+    dual_cross_attention: bool = False,
+    use_linear_projection: bool = True,
+    only_cross_attention: bool = False,
+    upcast_attention: bool = False,
+    resnet_time_scale_shift: str = "default",
+    temporal_num_attention_heads: int = 8,
+    temporal_max_seq_length: int = 32,
+    transformer_layers_per_block: int = 1,
+) -> Union[
+    "DownBlock3D",
+    "CrossAttnDownBlock3D",
+    "DownBlockMotion",
+    "CrossAttnDownBlockMotion",
+    "DownBlockSpatioTemporal",
+    "CrossAttnDownBlockSpatioTemporal",
+]:
     if down_block_type == "DownBlock3D":
         return DownBlock3D(
             num_layers=num_layers,
@@ -118,33 +136,65 @@ def get_down_block(
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
         )
+    elif down_block_type == "DownBlockSpatioTemporal":
+        # added for SDV
+        return DownBlockSpatioTemporal(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+        )
+    elif down_block_type == "CrossAttnDownBlockSpatioTemporal":
+        # added for SDV
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal")
+        return CrossAttnDownBlockSpatioTemporal(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
+            add_downsample=add_downsample,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads,
+        )

     raise ValueError(f"{down_block_type} does not exist.")


 def get_up_block(
-    up_block_type,
-    num_layers,
-    in_channels,
-    out_channels,
-    prev_output_channel,
-    temb_channels,
-    add_upsample,
-    resnet_eps,
-    resnet_act_fn,
-    num_attention_heads,
-    resolution_idx=None,
-    resnet_groups=None,
-    cross_attention_dim=None,
-    dual_cross_attention=False,
-    use_linear_projection=True,
-    only_cross_attention=False,
-    upcast_attention=False,
-    resnet_time_scale_shift="default",
-    temporal_num_attention_heads=8,
-    temporal_cross_attention_dim=None,
-    temporal_max_seq_length=32,
-):
+    up_block_type: str,
+    num_layers: int,
+    in_channels: int,
+    out_channels: int,
+    prev_output_channel: int,
+    temb_channels: int,
+    add_upsample: bool,
+    resnet_eps: float,
+    resnet_act_fn: str,
+    num_attention_heads: int,
+    resolution_idx: Optional[int] = None,
+    resnet_groups: Optional[int] = None,
+    cross_attention_dim: Optional[int] = None,
+    dual_cross_attention: bool = False,
+    use_linear_projection: bool = True,
+    only_cross_attention: bool = False,
+    upcast_attention: bool = False,
+    resnet_time_scale_shift: str = "default",
+    temporal_num_attention_heads: int = 8,
+    temporal_cross_attention_dim: Optional[int] = None,
+    temporal_max_seq_length: int = 32,
+    transformer_layers_per_block: int = 1,
+    dropout: float = 0.0,
+) -> Union[
+    "UpBlock3D",
+    "CrossAttnUpBlock3D",
+    "UpBlockMotion",
+    "CrossAttnUpBlockMotion",
+    "UpBlockSpatioTemporal",
+    "CrossAttnUpBlockSpatioTemporal",
+]:
     if up_block_type == "UpBlock3D":
         return UpBlock3D(
             num_layers=num_layers,
@@ -221,6 +271,34 @@ def get_up_block(
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
         )
+    elif up_block_type == "UpBlockSpatioTemporal":
+        # added for SDV
+        return UpBlockSpatioTemporal(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            resolution_idx=resolution_idx,
+            add_upsample=add_upsample,
+        )
+    elif up_block_type == "CrossAttnUpBlockSpatioTemporal":
+        # added for SDV
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal")
+        return CrossAttnUpBlockSpatioTemporal(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
+            add_upsample=add_upsample,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads,
+            resolution_idx=resolution_idx,
+        )
+
     raise ValueError(f"{up_block_type} does not exist.")


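Note: the hunks above retype the get_down_block/get_up_block factories, add a transformer_layers_per_block argument, and register the new SpatioTemporal branches (marked "added for SDV", i.e. Stable Video Diffusion). A minimal sketch of driving one new branch through the factory, assuming diffusers==0.25.0 is installed; the channel and head counts are illustrative, not from any shipped config:

# Minimal sketch, assuming diffusers==0.25.0; argument values are illustrative.
from diffusers.models.unet_3d_blocks import get_down_block

block = get_down_block(
    "CrossAttnDownBlockSpatioTemporal",  # new branch, added for SDV
    num_layers=2,
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    num_attention_heads=5,
    cross_attention_dim=1024,  # must not be None for the cross-attention variant
    transformer_layers_per_block=1,
)
print(type(block).__name__)  # CrossAttnDownBlockSpatioTemporal
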
@@ -236,12 +314,12 @@ class UNetMidBlock3DCrossAttn(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads=1,
-        output_scale_factor=1.0,
-        cross_attention_dim=1280,
-        dual_cross_attention=False,
-        use_linear_projection=True,
-        upcast_attention=False,
+        num_attention_heads: int = 1,
+        output_scale_factor: float = 1.0,
+        cross_attention_dim: int = 1280,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = True,
+        upcast_attention: bool = False,
     ):
         super().__init__()

@@ -269,6 +347,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
                 in_channels,
                 in_channels,
                 dropout=0.1,
+                norm_num_groups=resnet_groups,
             )
         ]
         attentions = []
@@ -316,6 +395,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
                     in_channels,
                     in_channels,
                     dropout=0.1,
+                    norm_num_groups=resnet_groups,
                 )
             )

@@ -326,13 +406,13 @@ class UNetMidBlock3DCrossAttn(nn.Module):

     def forward(
         self,
-        hidden_states,
-        temb=None,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        num_frames=1,
-        cross_attention_kwargs=None,
-    ):
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> torch.FloatTensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
         for attn, temp_attn, resnet, temp_conv in zip(
@@ -345,7 +425,10 @@ class UNetMidBlock3DCrossAttn(nn.Module):
                 return_dict=False,
             )[0]
             hidden_states = temp_attn(
-                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+                hidden_states,
+                num_frames=num_frames,
+                cross_attention_kwargs=cross_attention_kwargs,
+                return_dict=False,
             )[0]
             hidden_states = resnet(hidden_states, temb)
             hidden_states = temp_conv(hidden_states, num_frames=num_frames)
@@ -366,15 +449,15 @@ class CrossAttnDownBlock3D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads=1,
-        cross_attention_dim=1280,
-        output_scale_factor=1.0,
-        downsample_padding=1,
-        add_downsample=True,
-        dual_cross_attention=False,
-        use_linear_projection=False,
-        only_cross_attention=False,
-        upcast_attention=False,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        downsample_padding: int = 1,
+        add_downsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
     ):
         super().__init__()
         resnets = []
@@ -406,6 +489,7 @@ class CrossAttnDownBlock3D(nn.Module):
                     out_channels,
                     out_channels,
                     dropout=0.1,
+                    norm_num_groups=resnet_groups,
                 )
             )
             attentions.append(
@@ -440,7 +524,11 @@ class CrossAttnDownBlock3D(nn.Module):
             self.downsamplers = nn.ModuleList(
                 [
                     Downsample2D(
-                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                        out_channels,
+                        use_conv=True,
+                        out_channels=out_channels,
+                        padding=downsample_padding,
+                        name="op",
                     )
                 ]
             )
@@ -451,13 +539,13 @@ class CrossAttnDownBlock3D(nn.Module):

     def forward(
         self,
-        hidden_states,
-        temb=None,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        num_frames=1,
-        cross_attention_kwargs=None,
-    ):
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+        cross_attention_kwargs: Dict[str, Any] = None,
+    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
         # TODO(Patrick, William) - attention mask is not used
         output_states = ()

@@ -473,7 +561,10 @@ class CrossAttnDownBlock3D(nn.Module):
                 return_dict=False,
             )[0]
             hidden_states = temp_attn(
-                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+                hidden_states,
+                num_frames=num_frames,
+                cross_attention_kwargs=cross_attention_kwargs,
+                return_dict=False,
             )[0]

             output_states += (hidden_states,)
@@ -500,9 +591,9 @@ class DownBlock3D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        output_scale_factor=1.0,
-        add_downsample=True,
-        downsample_padding=1,
+        output_scale_factor: float = 1.0,
+        add_downsample: bool = True,
+        downsample_padding: int = 1,
     ):
         super().__init__()
         resnets = []
@@ -529,6 +620,7 @@ class DownBlock3D(nn.Module):
                     out_channels,
                     out_channels,
                     dropout=0.1,
+                    norm_num_groups=resnet_groups,
                 )
             )

@@ -539,7 +631,11 @@ class DownBlock3D(nn.Module):
             self.downsamplers = nn.ModuleList(
                 [
                     Downsample2D(
-                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                        out_channels,
+                        use_conv=True,
+                        out_channels=out_channels,
+                        padding=downsample_padding,
+                        name="op",
                     )
                 ]
             )
@@ -548,7 +644,12 @@ class DownBlock3D(nn.Module):

         self.gradient_checkpointing = False

-    def forward(self, hidden_states, temb=None, num_frames=1):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
         output_states = ()

         for resnet, temp_conv in zip(self.resnets, self.temp_convs):
@@ -580,15 +681,15 @@ class CrossAttnUpBlock3D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads=1,
-        cross_attention_dim=1280,
-        output_scale_factor=1.0,
-        add_upsample=True,
-        dual_cross_attention=False,
-        use_linear_projection=False,
-        only_cross_attention=False,
-        upcast_attention=False,
-        resolution_idx=None,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        resolution_idx: Optional[int] = None,
     ):
         super().__init__()
         resnets = []
@@ -622,6 +723,7 @@ class CrossAttnUpBlock3D(nn.Module):
                     out_channels,
                     out_channels,
                     dropout=0.1,
+                    norm_num_groups=resnet_groups,
                 )
             )
             attentions.append(
@@ -662,15 +764,15 @@ class CrossAttnUpBlock3D(nn.Module):

     def forward(
         self,
-        hidden_states,
-        res_hidden_states_tuple,
-        temb=None,
-        encoder_hidden_states=None,
-        upsample_size=None,
-        attention_mask=None,
-        num_frames=1,
-        cross_attention_kwargs=None,
-    ):
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+        cross_attention_kwargs: Dict[str, Any] = None,
+    ) -> torch.FloatTensor:
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -709,7 +811,10 @@ class CrossAttnUpBlock3D(nn.Module):
                 return_dict=False,
             )[0]
             hidden_states = temp_attn(
-                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+                hidden_states,
+                num_frames=num_frames,
+                cross_attention_kwargs=cross_attention_kwargs,
+                return_dict=False,
             )[0]

         if self.upsamplers is not None:
@@ -733,9 +838,9 @@ class UpBlock3D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        output_scale_factor=1.0,
-        add_upsample=True,
-        resolution_idx=None,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        resolution_idx: Optional[int] = None,
     ):
         super().__init__()
         resnets = []
@@ -764,6 +869,7 @@ class UpBlock3D(nn.Module):
                     out_channels,
                     out_channels,
                     dropout=0.1,
+                    norm_num_groups=resnet_groups,
                 )
             )

@@ -778,7 +884,14 @@ class UpBlock3D(nn.Module):
         self.gradient_checkpointing = False
         self.resolution_idx = resolution_idx

-    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
+        num_frames: int = 1,
+    ) -> torch.FloatTensor:
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -827,12 +940,12 @@ class DownBlockMotion(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        output_scale_factor=1.0,
-        add_downsample=True,
-        downsample_padding=1,
-        temporal_num_attention_heads=1,
-        temporal_cross_attention_dim=None,
-        temporal_max_seq_length=32,
+        output_scale_factor: float = 1.0,
+        add_downsample: bool = True,
+        downsample_padding: int = 1,
+        temporal_num_attention_heads: int = 1,
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_max_seq_length: int = 32,
     ):
         super().__init__()
         resnets = []
@@ -875,7 +988,11 @@ class DownBlockMotion(nn.Module):
             self.downsamplers = nn.ModuleList(
                 [
                     Downsample2D(
-                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                        out_channels,
+                        use_conv=True,
+                        out_channels=out_channels,
+                        padding=downsample_padding,
+                        name="op",
                     )
                 ]
             )
@@ -884,7 +1001,13 @@ class DownBlockMotion(nn.Module):

         self.gradient_checkpointing = False

-    def forward(self, hidden_states, temb=None, scale=1.0, num_frames=1):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+        num_frames: int = 1,
+    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
         output_states = ()

         blocks = zip(self.resnets, self.motion_modules)
@@ -899,14 +1022,20 @@ class DownBlockMotion(nn.Module):

                 if is_torch_version(">=", "1.11.0"):
                     hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, scale, use_reentrant=False
+                        create_custom_forward(resnet),
+                        hidden_states,
+                        temb,
+                        use_reentrant=False,
                     )
                 else:
                     hidden_states = torch.utils.checkpoint.checkpoint(
                         create_custom_forward(resnet), hidden_states, temb, scale
                     )
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(motion_module), hidden_states, temb, num_frames
+                    create_custom_forward(motion_module),
+                    hidden_states.requires_grad_(),
+                    temb,
+                    num_frames,
                 )

             else:
@@ -938,19 +1067,19 @@ class CrossAttnDownBlockMotion(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads=1,
-        cross_attention_dim=1280,
-        output_scale_factor=1.0,
-        downsample_padding=1,
-        add_downsample=True,
-        dual_cross_attention=False,
-        use_linear_projection=False,
-        only_cross_attention=False,
-        upcast_attention=False,
-        attention_type="default",
-        temporal_cross_attention_dim=None,
-        temporal_num_attention_heads=8,
-        temporal_max_seq_length=32,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        downsample_padding: int = 1,
+        add_downsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        attention_type: str = "default",
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_num_attention_heads: int = 8,
+        temporal_max_seq_length: int = 32,
     ):
         super().__init__()
         resnets = []
@@ -1026,7 +1155,11 @@ class CrossAttnDownBlockMotion(nn.Module):
             self.downsamplers = nn.ModuleList(
                 [
                     Downsample2D(
-                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                        out_channels,
+                        use_conv=True,
+                        out_channels=out_channels,
+                        padding=downsample_padding,
+                        name="op",
                     )
                 ]
             )
@@ -1037,14 +1170,14 @@ class CrossAttnDownBlockMotion(nn.Module):

     def forward(
         self,
-        hidden_states,
-        temb=None,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        num_frames=1,
-        encoder_attention_mask=None,
-        cross_attention_kwargs=None,
-        additional_residuals=None,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        additional_residuals: Optional[torch.FloatTensor] = None,
     ):
         output_states = ()

@@ -1115,7 +1248,7 @@ class CrossAttnUpBlockMotion(nn.Module):
         out_channels: int,
         prev_output_channel: int,
         temb_channels: int,
-        resolution_idx: int = None,
+        resolution_idx: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
         transformer_layers_per_block: int = 1,
@@ -1124,18 +1257,18 @@ class CrossAttnUpBlockMotion(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads=1,
-        cross_attention_dim=1280,
-        output_scale_factor=1.0,
-        add_upsample=True,
-        dual_cross_attention=False,
-        use_linear_projection=False,
-        only_cross_attention=False,
-        upcast_attention=False,
-        attention_type="default",
-        temporal_cross_attention_dim=None,
-        temporal_num_attention_heads=8,
-        temporal_max_seq_length=32,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        attention_type: str = "default",
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_num_attention_heads: int = 8,
+        temporal_max_seq_length: int = 32,
     ):
         super().__init__()
         resnets = []
@@ -1226,8 +1359,8 @@ class CrossAttnUpBlockMotion(nn.Module):
         upsample_size: Optional[int] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        num_frames=1,
-    ):
+        num_frames: int = 1,
+    ) -> torch.FloatTensor:
         lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
         is_freeu_enabled = (
             getattr(self, "s1", None)
@@ -1311,7 +1444,7 @@ class UpBlockMotion(nn.Module):
         prev_output_channel: int,
         out_channels: int,
         temb_channels: int,
-        resolution_idx: int = None,
+        resolution_idx: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
         resnet_eps: float = 1e-6,
@@ -1319,12 +1452,12 @@ class UpBlockMotion(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        output_scale_factor=1.0,
-        add_upsample=True,
-        temporal_norm_num_groups=32,
-        temporal_cross_attention_dim=None,
-        temporal_num_attention_heads=8,
-        temporal_max_seq_length=32,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        temporal_norm_num_groups: int = 32,
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_num_attention_heads: int = 8,
+        temporal_max_seq_length: int = 32,
     ):
         super().__init__()
         resnets = []
@@ -1375,8 +1508,14 @@ class UpBlockMotion(nn.Module):
         self.resolution_idx = resolution_idx

     def forward(
-        self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale=1.0, num_frames=1
-    ):
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        upsample_size=None,
+        scale: float = 1.0,
+        num_frames: int = 1,
+    ) -> torch.FloatTensor:
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -1415,7 +1554,10 @@ class UpBlockMotion(nn.Module):

                 if is_torch_version(">=", "1.11.0"):
                     hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, scale, use_reentrant=False
+                        create_custom_forward(resnet),
+                        hidden_states,
+                        temb,
+                        use_reentrant=False,
                     )
                 else:
                     hidden_states = torch.utils.checkpoint.checkpoint(
@@ -1451,16 +1593,16 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads=1,
-        output_scale_factor=1.0,
-        cross_attention_dim=1280,
-        dual_cross_attention=False,
-        use_linear_projection=False,
-        upcast_attention=False,
-        attention_type="default",
-        temporal_num_attention_heads=1,
-        temporal_cross_attention_dim=None,
-        temporal_max_seq_length=32,
+        num_attention_heads: int = 1,
+        output_scale_factor: float = 1.0,
+        cross_attention_dim: int = 1280,
+        dual_cross_attention: float = False,
+        use_linear_projection: float = False,
+        upcast_attention: float = False,
+        attention_type: str = "default",
+        temporal_num_attention_heads: int = 1,
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_max_seq_length: int = 32,
     ):
         super().__init__()

@@ -1554,7 +1696,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        num_frames=1,
+        num_frames: int = 1,
     ) -> torch.FloatTensor:
         lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
         hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
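Note: the Motion-block hunks above apply the same annotation sweep and also adjust gradient checkpointing — on torch >= 1.11 the resnet checkpoint call now passes use_reentrant=False, and the motion-module call checkpoints hidden_states.requires_grad_(). A generic sketch of the version-gated pattern these hunks converge on (the checkpointed helper below is illustrative, not a diffusers API):

# Generic sketch of the version-gated checkpoint pattern; `checkpointed` is illustrative.
import torch
from diffusers.utils import is_torch_version

def checkpointed(fn, *inputs):
    if is_torch_version(">=", "1.11.0"):
        # non-reentrant checkpointing, as the 0.25.0 blocks now request
        return torch.utils.checkpoint.checkpoint(fn, *inputs, use_reentrant=False)
    return torch.utils.checkpoint.checkpoint(fn, *inputs)

layer = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)
checkpointed(layer, x).sum().backward()
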
@@ -1609,3 +1751,645 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
|
|
1609
1751
|
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
1610
1752
|
|
1611
1753
|
return hidden_states
|
1754
|
+
|
1755
|
+
|
1756
|
+
class MidBlockTemporalDecoder(nn.Module):
|
1757
|
+
def __init__(
|
1758
|
+
self,
|
1759
|
+
in_channels: int,
|
1760
|
+
out_channels: int,
|
1761
|
+
attention_head_dim: int = 512,
|
1762
|
+
num_layers: int = 1,
|
1763
|
+
upcast_attention: bool = False,
|
1764
|
+
):
|
1765
|
+
super().__init__()
|
1766
|
+
|
1767
|
+
resnets = []
|
1768
|
+
attentions = []
|
1769
|
+
for i in range(num_layers):
|
1770
|
+
input_channels = in_channels if i == 0 else out_channels
|
1771
|
+
resnets.append(
|
1772
|
+
SpatioTemporalResBlock(
|
1773
|
+
in_channels=input_channels,
|
1774
|
+
out_channels=out_channels,
|
1775
|
+
temb_channels=None,
|
1776
|
+
eps=1e-6,
|
1777
|
+
temporal_eps=1e-5,
|
1778
|
+
merge_factor=0.0,
|
1779
|
+
merge_strategy="learned",
|
1780
|
+
switch_spatial_to_temporal_mix=True,
|
1781
|
+
)
|
1782
|
+
)
|
1783
|
+
|
1784
|
+
attentions.append(
|
1785
|
+
Attention(
|
1786
|
+
query_dim=in_channels,
|
1787
|
+
heads=in_channels // attention_head_dim,
|
1788
|
+
dim_head=attention_head_dim,
|
1789
|
+
eps=1e-6,
|
1790
|
+
upcast_attention=upcast_attention,
|
1791
|
+
norm_num_groups=32,
|
1792
|
+
bias=True,
|
1793
|
+
residual_connection=True,
|
1794
|
+
)
|
1795
|
+
)
|
1796
|
+
|
1797
|
+
self.attentions = nn.ModuleList(attentions)
|
1798
|
+
self.resnets = nn.ModuleList(resnets)
|
1799
|
+
|
1800
|
+
def forward(
|
1801
|
+
self,
|
1802
|
+
hidden_states: torch.FloatTensor,
|
1803
|
+
image_only_indicator: torch.FloatTensor,
|
1804
|
+
):
|
1805
|
+
hidden_states = self.resnets[0](
|
1806
|
+
hidden_states,
|
1807
|
+
image_only_indicator=image_only_indicator,
|
1808
|
+
)
|
1809
|
+
for resnet, attn in zip(self.resnets[1:], self.attentions):
|
1810
|
+
hidden_states = attn(hidden_states)
|
1811
|
+
hidden_states = resnet(
|
1812
|
+
hidden_states,
|
1813
|
+
image_only_indicator=image_only_indicator,
|
1814
|
+
)
|
1815
|
+
|
1816
|
+
return hidden_states
|
1817
|
+
|
1818
|
+
|
1819
|
+
class UpBlockTemporalDecoder(nn.Module):
|
1820
|
+
def __init__(
|
1821
|
+
self,
|
1822
|
+
in_channels: int,
|
1823
|
+
out_channels: int,
|
1824
|
+
num_layers: int = 1,
|
1825
|
+
add_upsample: bool = True,
|
1826
|
+
):
|
1827
|
+
super().__init__()
|
1828
|
+
resnets = []
|
1829
|
+
for i in range(num_layers):
|
1830
|
+
input_channels = in_channels if i == 0 else out_channels
|
1831
|
+
|
1832
|
+
resnets.append(
|
1833
|
+
SpatioTemporalResBlock(
|
1834
|
+
in_channels=input_channels,
|
1835
|
+
out_channels=out_channels,
|
1836
|
+
temb_channels=None,
|
1837
|
+
eps=1e-6,
|
1838
|
+
temporal_eps=1e-5,
|
1839
|
+
merge_factor=0.0,
|
1840
|
+
merge_strategy="learned",
|
1841
|
+
switch_spatial_to_temporal_mix=True,
|
1842
|
+
)
|
1843
|
+
)
|
1844
|
+
self.resnets = nn.ModuleList(resnets)
|
1845
|
+
|
1846
|
+
if add_upsample:
|
1847
|
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1848
|
+
else:
|
1849
|
+
self.upsamplers = None
|
1850
|
+
|
1851
|
+
def forward(
|
1852
|
+
self,
|
1853
|
+
hidden_states: torch.FloatTensor,
|
1854
|
+
image_only_indicator: torch.FloatTensor,
|
1855
|
+
) -> torch.FloatTensor:
|
1856
|
+
for resnet in self.resnets:
|
1857
|
+
hidden_states = resnet(
|
1858
|
+
hidden_states,
|
1859
|
+
image_only_indicator=image_only_indicator,
|
1860
|
+
)
|
1861
|
+
|
1862
|
+
if self.upsamplers is not None:
|
1863
|
+
for upsampler in self.upsamplers:
|
1864
|
+
hidden_states = upsampler(hidden_states)
|
1865
|
+
|
1866
|
+
return hidden_states
|
1867
|
+
|
1868
|
+
|
1869
|
+
class UNetMidBlockSpatioTemporal(nn.Module):
|
1870
|
+
def __init__(
|
1871
|
+
self,
|
1872
|
+
in_channels: int,
|
1873
|
+
temb_channels: int,
|
1874
|
+
num_layers: int = 1,
|
1875
|
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
1876
|
+
num_attention_heads: int = 1,
|
1877
|
+
cross_attention_dim: int = 1280,
|
1878
|
+
):
|
1879
|
+
super().__init__()
|
1880
|
+
|
1881
|
+
self.has_cross_attention = True
|
1882
|
+
self.num_attention_heads = num_attention_heads
|
1883
|
+
|
1884
|
+
# support for variable transformer layers per block
|
1885
|
+
if isinstance(transformer_layers_per_block, int):
|
1886
|
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
1887
|
+
|
1888
|
+
# there is always at least one resnet
|
1889
|
+
resnets = [
|
1890
|
+
SpatioTemporalResBlock(
|
1891
|
+
in_channels=in_channels,
|
1892
|
+
out_channels=in_channels,
|
1893
|
+
temb_channels=temb_channels,
|
1894
|
+
eps=1e-5,
|
1895
|
+
)
|
1896
|
+
]
|
1897
|
+
attentions = []
|
1898
|
+
|
1899
|
+
for i in range(num_layers):
|
1900
|
+
attentions.append(
|
1901
|
+
TransformerSpatioTemporalModel(
|
1902
|
+
num_attention_heads,
|
1903
|
+
in_channels // num_attention_heads,
|
1904
|
+
in_channels=in_channels,
|
1905
|
+
num_layers=transformer_layers_per_block[i],
|
1906
|
+
cross_attention_dim=cross_attention_dim,
|
1907
|
+
)
|
1908
|
+
)
|
1909
|
+
|
1910
|
+
resnets.append(
|
1911
|
+
SpatioTemporalResBlock(
|
1912
|
+
in_channels=in_channels,
|
1913
|
+
out_channels=in_channels,
|
1914
|
+
temb_channels=temb_channels,
|
1915
|
+
eps=1e-5,
|
1916
|
+
)
|
1917
|
+
)
|
1918
|
+
|
1919
|
+
self.attentions = nn.ModuleList(attentions)
|
1920
|
+
self.resnets = nn.ModuleList(resnets)
|
1921
|
+
|
1922
|
+
self.gradient_checkpointing = False
|
1923
|
+
|
1924
|
+
def forward(
|
1925
|
+
self,
|
1926
|
+
hidden_states: torch.FloatTensor,
|
1927
|
+
temb: Optional[torch.FloatTensor] = None,
|
1928
|
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1929
|
+
image_only_indicator: Optional[torch.Tensor] = None,
|
1930
|
+
) -> torch.FloatTensor:
|
1931
|
+
hidden_states = self.resnets[0](
|
1932
|
+
hidden_states,
|
1933
|
+
temb,
|
1934
|
+
image_only_indicator=image_only_indicator,
|
1935
|
+
)
|
1936
|
+
|
1937
|
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
1938
|
+
if self.training and self.gradient_checkpointing: # TODO
|
1939
|
+
|
1940
|
+
def create_custom_forward(module, return_dict=None):
|
1941
|
+
def custom_forward(*inputs):
|
1942
|
+
if return_dict is not None:
|
1943
|
+
return module(*inputs, return_dict=return_dict)
|
1944
|
+
else:
|
1945
|
+
return module(*inputs)
|
1946
|
+
|
1947
|
+
return custom_forward
|
1948
|
+
|
1949
|
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1950
|
+
hidden_states = attn(
|
1951
|
+
hidden_states,
|
1952
|
+
encoder_hidden_states=encoder_hidden_states,
|
1953
|
+
image_only_indicator=image_only_indicator,
|
1954
|
+
return_dict=False,
|
1955
|
+
)[0]
|
1956
|
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1957
|
+
create_custom_forward(resnet),
|
1958
|
+
hidden_states,
|
1959
|
+
temb,
|
1960
|
+
image_only_indicator,
|
1961
|
+
**ckpt_kwargs,
|
1962
|
+
)
|
1963
|
+
else:
|
1964
|
+
hidden_states = attn(
|
1965
|
+
hidden_states,
|
1966
|
+
encoder_hidden_states=encoder_hidden_states,
|
1967
|
+
image_only_indicator=image_only_indicator,
|
1968
|
+
return_dict=False,
|
1969
|
+
)[0]
|
1970
|
+
hidden_states = resnet(
|
1971
|
+
hidden_states,
|
1972
|
+
temb,
|
1973
|
+
image_only_indicator=image_only_indicator,
|
1974
|
+
)
|
1975
|
+
|
1976
|
+
return hidden_states
|
1977
|
+
|
1978
|
+
|
1979
|
+
class DownBlockSpatioTemporal(nn.Module):
|
1980
|
+
def __init__(
|
1981
|
+
self,
|
1982
|
+
in_channels: int,
|
1983
|
+
out_channels: int,
|
1984
|
+
temb_channels: int,
|
1985
|
+
num_layers: int = 1,
|
1986
|
+
add_downsample: bool = True,
|
1987
|
+
):
|
1988
|
+
super().__init__()
|
1989
|
+
resnets = []
|
1990
|
+
|
1991
|
+
for i in range(num_layers):
|
1992
|
+
in_channels = in_channels if i == 0 else out_channels
|
1993
|
+
resnets.append(
|
1994
|
+
SpatioTemporalResBlock(
|
1995
|
+
in_channels=in_channels,
|
1996
|
+
out_channels=out_channels,
|
1997
|
+
temb_channels=temb_channels,
|
1998
|
+
eps=1e-5,
|
1999
|
+
)
|
2000
|
+
)
|
2001
|
+
|
2002
|
+
self.resnets = nn.ModuleList(resnets)
|
2003
|
+
|
2004
|
+
if add_downsample:
|
2005
|
+
self.downsamplers = nn.ModuleList(
|
2006
|
+
[
|
2007
|
+
Downsample2D(
|
2008
|
+
out_channels,
|
2009
|
+
use_conv=True,
|
2010
|
+
out_channels=out_channels,
|
2011
|
+
name="op",
|
2012
|
+
)
|
2013
|
+
]
|
2014
|
+
)
|
2015
|
+
else:
|
2016
|
+
self.downsamplers = None
|
2017
|
+
|
2018
|
+
self.gradient_checkpointing = False
|
2019
|
+
|
2020
|
+
def forward(
|
2021
|
+
self,
|
2022
|
+
hidden_states: torch.FloatTensor,
|
2023
|
+
temb: Optional[torch.FloatTensor] = None,
|
2024
|
+
image_only_indicator: Optional[torch.Tensor] = None,
|
2025
|
+
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
2026
|
+
output_states = ()
|
2027
|
+
for resnet in self.resnets:
|
2028
|
+
if self.training and self.gradient_checkpointing:
|
2029
|
+
|
2030
|
+
def create_custom_forward(module):
|
2031
|
+
def custom_forward(*inputs):
|
2032
|
+
return module(*inputs)
|
2033
|
+
|
2034
|
+
return custom_forward
|
2035
|
+
|
2036
|
+
if is_torch_version(">=", "1.11.0"):
|
2037
|
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
2038
|
+
create_custom_forward(resnet),
|
2039
|
+
hidden_states,
|
2040
|
+
temb,
|
2041
|
+
image_only_indicator,
|
2042
|
+
use_reentrant=False,
|
2043
|
+
)
|
2044
|
+
else:
|
2045
|
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
2046
|
+
create_custom_forward(resnet),
|
2047
|
+
hidden_states,
|
2048
|
+
temb,
|
2049
|
+
image_only_indicator,
|
2050
|
+
)
|
2051
|
+
else:
|
2052
|
+
hidden_states = resnet(
|
2053
|
+
hidden_states,
|
2054
|
+
temb,
|
2055
|
+
image_only_indicator=image_only_indicator,
|
2056
|
+
)
|
2057
|
+
|
2058
|
+
output_states = output_states + (hidden_states,)
|
2059
|
+
|
2060
|
+
if self.downsamplers is not None:
|
2061
|
+
for downsampler in self.downsamplers:
|
2062
|
+
hidden_states = downsampler(hidden_states)
|
2063
|
+
|
2064
|
+
output_states = output_states + (hidden_states,)
|
2065
|
+
|
2066
|
+
return hidden_states, output_states
|
2067
|
+
|
2068
|
+
|
2069
|
+
class CrossAttnDownBlockSpatioTemporal(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        add_downsample: bool = True,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=1e-6,
                )
            )
            attentions.append(
                TransformerSpatioTemporalModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[i],
                    cross_attention_dim=cross_attention_dim,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=1,
                        name="op",
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
        output_states = ()

        blocks = list(zip(self.resnets, self.attentions))
        for resnet, attn in blocks:
            if self.training and self.gradient_checkpointing:  # TODO

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )

                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]
            else:
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]

            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states = output_states + (hidden_states,)

        return hidden_states, output_states
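
Again as a non-authoritative sketch: the cross-attention variant additionally takes `encoder_hidden_states`, which each `TransformerSpatioTemporalModel` attends over (in Stable Video Diffusion this is a CLIP image embedding, repeated per frame). The dimensions below are assumptions, chosen so the head math (`out_channels // num_attention_heads`) works out.

import torch
from diffusers.models.unet_3d_blocks import CrossAttnDownBlockSpatioTemporal

batch, frames = 1, 4
block = CrossAttnDownBlockSpatioTemporal(
    in_channels=64,
    out_channels=64,          # 64 channels / 8 heads -> head dim 8
    temb_channels=512,
    num_attention_heads=8,
    cross_attention_dim=1024,
)

hidden_states = torch.randn(batch * frames, 64, 32, 32)
temb = torch.randn(batch * frames, 512)
encoder_hidden_states = torch.randn(batch * frames, 1, 1024)  # one token per frame (assumed)
image_only_indicator = torch.zeros(batch, frames)

sample, skip_states = block(
    hidden_states,
    temb,
    encoder_hidden_states=encoder_hidden_states,
    image_only_indicator=image_only_indicator,
)
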
class UpBlockSpatioTemporal(nn.Module):
    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        add_upsample: bool = True,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        for resnet in self.resnets:
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                if is_torch_version(">=", "1.11.0"):
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        image_only_indicator,
                        use_reentrant=False,
                    )
                else:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        image_only_indicator,
                    )
            else:
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states
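
A hedged sketch of the channel bookkeeping above: with `num_layers=1`, the single resnet takes `prev_output_channel + in_channels` input channels, because the incoming feature map is concatenated with the last entry popped from `res_hidden_states_tuple`. Import path and shapes are assumptions, as before.

import torch
from diffusers.models.unet_3d_blocks import UpBlockSpatioTemporal

batch, frames = 1, 4
block = UpBlockSpatioTemporal(
    in_channels=64,           # channels of the matching skip connection
    prev_output_channel=128,  # channels arriving from the block below
    out_channels=64,
    temb_channels=512,
)

hidden_states = torch.randn(batch * frames, 128, 16, 16)
skips = (torch.randn(batch * frames, 64, 16, 16),)  # one skip per resnet
temb = torch.randn(batch * frames, 512)
image_only_indicator = torch.zeros(batch, frames)

sample = block(hidden_states, skips, temb, image_only_indicator)
# cat(128 + 64) -> resnet -> 64 channels -> Upsample2D: (batch * frames, 64, 32, 32)
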
class CrossAttnUpBlockSpatioTemporal(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        add_upsample: bool = True,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                )
            )
            attentions.append(
                TransformerSpatioTemporalModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[i],
                    cross_attention_dim=cross_attention_dim,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:  # TODO

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]
            else:
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states
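
Finally, a speculative sketch of exercising the `self.training and self.gradient_checkpointing` branch shared by all four blocks. On a full model the flag is normally toggled through `enable_gradient_checkpointing()`; setting the attribute directly on a bare block, as done here, is only for illustration.

import torch
from diffusers.models.unet_3d_blocks import CrossAttnUpBlockSpatioTemporal

block = CrossAttnUpBlockSpatioTemporal(
    in_channels=64,
    out_channels=64,
    prev_output_channel=64,
    temb_channels=512,
    num_attention_heads=8,
    cross_attention_dim=1024,
)
block.gradient_checkpointing = True
block.train()  # both conditions of the checkpointing branch now hold

batch, frames = 1, 2
hidden_states = torch.randn(batch * frames, 64, 16, 16, requires_grad=True)
skips = (torch.randn(batch * frames, 64, 16, 16),)
temb = torch.randn(batch * frames, 512)
encoder_hidden_states = torch.randn(batch * frames, 1, 1024)
image_only_indicator = torch.zeros(batch, frames)

out = block(
    hidden_states,
    skips,
    temb,
    encoder_hidden_states=encoder_hidden_states,
    image_only_indicator=image_only_indicator,
)
out.mean().backward()  # resnet activations are recomputed during this backward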