diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
- diffusers-0.28.0.dist-info/RECORD +414 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
The diff body below shows excerpts from the first few changed files; in each hunk header the left-hand line numbers refer to 0.27.2 and the right-hand numbers to 0.28.0.

diffusers/models/unets/unet_2d_condition.py

```diff
@@ -20,6 +20,7 @@ import torch.utils.checkpoint
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
+from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ..activations import get_activation
 from ..attention_processor import (
@@ -59,14 +60,16 @@ class UNet2DConditionOutput(BaseOutput):
     The output of [`UNet2DConditionModel`].
 
     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
             The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
     """
 
-    sample: torch.FloatTensor = None
+    sample: torch.Tensor = None
 
 
-class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
+class UNet2DConditionModel(
+    ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
+):
     r"""
     A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
     shaped output.
```
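The new `FromOriginalModelMixin` base (defined in the new `diffusers/loaders/single_file_model.py` listed above) is what exposes single-file checkpoint loading directly on the model class. A minimal sketch, assuming a local `.safetensors` file whose weights this model class can map (the path is illustrative):

```python
from diffusers import UNet2DConditionModel

# Hypothetical single-file Stable Diffusion checkpoint on disk; a URL to a
# supported .safetensors/.ckpt file should work the same way.
unet = UNet2DConditionModel.from_single_file("./v1-5-pruned-emaonly.safetensors")
print(unet.config.sample_size)
```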
diffusers/models/unets/unet_2d_condition.py (continued)

```diff
@@ -161,6 +164,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
 
     @register_to_config
     def __init__(
```
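`_no_split_modules` declares which submodule classes must stay together on one device when the model is dispatched across devices. A minimal sketch using accelerate's `infer_auto_device_map`; the model id and the single-call usage are illustrative, not part of this diff:

```python
import torch
from accelerate import infer_auto_device_map
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)
# Keep whole transformer/resnet blocks intact when computing a device placement.
device_map = infer_auto_device_map(unet, no_split_module_classes=unet._no_split_modules)
print(device_map)
```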
diffusers/models/unets/unet_2d_condition.py (continued)

```diff
@@ -580,7 +584,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         elif encoder_hid_dim_type == "text_image_proj":
             # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
             # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
-            # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
+            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
             self.encoder_hid_proj = TextImageProjection(
                 text_embed_dim=encoder_hid_dim,
                 image_embed_dim=cross_attention_dim,
@@ -660,7 +664,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         elif addition_embed_type == "text_image":
             # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
             # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
-            # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
+            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
             self.add_embedding = TextImageTimeEmbedding(
                 text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
             )
@@ -681,7 +685,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
             positive_len = 768
             if isinstance(cross_attention_dim, int):
                 positive_len = cross_attention_dim
-            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+            elif isinstance(cross_attention_dim, (list, tuple)):
                 positive_len = cross_attention_dim[0]
 
             feature_type = "text-only" if attention_type == "gated" else "text-image"
@@ -865,8 +869,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
 
     def fuse_qkv_projections(self):
         """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
 
         <Tip warning={true}>
 
```
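The reworded docstring describes behaviour the model already has; a minimal sketch of the call pair (the model id is illustrative):

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)
unet.fuse_qkv_projections()    # fuse q/k/v (self-attention) and k/v (cross-attention) projections
# ... run inference ...
unet.unfuse_qkv_projections()  # restore the original, unfused projection matrices
```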
diffusers/models/unets/unet_2d_condition.py (continued)

```diff
@@ -1010,7 +1014,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
         elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
-            # Kadinsky 2.1 - style
+            # Kandinsky 2.1 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
                     f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
@@ -1038,7 +1042,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         class_labels: Optional[torch.Tensor] = None,
@@ -1056,10 +1060,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         The [`UNet2DConditionModel`] forward method.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor with the following shape `(batch, channel, height, width)`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
-            encoder_hidden_states (`torch.FloatTensor`):
+            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                 Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
@@ -1093,8 +1097,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
 
         Returns:
             [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
-                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
-                a `tuple` is returned where the first element is the sample tensor.
+                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is the sample tensor.
         """
         # By default samples have to be AT least a multiple of the overall upsampling factor.
         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
```
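As the docstring says, `return_dict` only changes the output container, not the tensor. A minimal sketch with a tiny, untrained configuration (the config values and dummy shapes are illustrative):

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    layers_per_block=1,
).eval()

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 32)

with torch.no_grad():
    out_dataclass = unet(sample, 10, encoder_hidden_states).sample             # UNet2DConditionOutput.sample
    out_tuple = unet(sample, 10, encoder_hidden_states, return_dict=False)[0]  # first element of the tuple
print(torch.equal(out_dataclass, out_tuple))  # True
```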
diffusers/models/unets/unet_2d_condition_flax.py

```diff
@@ -76,7 +76,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
             The tuple of upsample blocks to use.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
-            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer is skipped.
+            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer
+            is skipped.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
             The tuple of output channels for each block.
         layers_per_block (`int`, *optional*, defaults to 2):
@@ -350,15 +351,15 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             mid_block_additional_residual: (`torch.Tensor`, *optional*):
                 A tensor that if specified is added to the residual of the middle unet block.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of
-                plain tuple.
+                Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of
+                a plain tuple.
             train (`bool`, *optional*, defaults to `False`):
                 Use deterministic functions and disable dropout when not training.
 
         Returns:
             [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
-                [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
-                When returning a tuple, the first element is the sample tensor.
+                [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
+                `tuple`. When returning a tuple, the first element is the sample tensor.
         """
         # 1. time
         if not isinstance(timesteps, jnp.ndarray):
```
diffusers/models/unets/unet_3d_blocks.py

```diff
@@ -121,6 +121,7 @@ def get_down_block(
             raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
         return CrossAttnDownBlockMotion(
             num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
             in_channels=in_channels,
             out_channels=out_channels,
             temb_channels=temb_channels,
@@ -255,6 +256,7 @@ def get_up_block(
             raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
         return CrossAttnUpBlockMotion(
             num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
             in_channels=in_channels,
             out_channels=out_channels,
             prev_output_channel=prev_output_channel,
```
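Forwarding `transformer_layers_per_block` into the Motion blocks matters for SDXL-style UNets, whose blocks stack more than one transformer layer (this is what the new `pipeline_animatediff_sdxl.py` above builds on). A quick, illustrative way to see the per-block values on an existing checkpoint (the model id is not part of this diff):

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
# SDXL uses a different number of transformer layers in each block, e.g. [1, 2, 10].
print(unet.config.transformer_layers_per_block)
```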
diffusers/models/unets/unet_3d_blocks.py (continued)

```diff
@@ -409,13 +411,13 @@ class UNetMidBlock3DCrossAttn(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         num_frames: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
         for attn, temp_attn, resnet, temp_conv in zip(
@@ -542,13 +544,13 @@ class CrossAttnDownBlock3D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         num_frames: int = 1,
         cross_attention_kwargs: Dict[str, Any] = None,
-    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         # TODO(Patrick, William) - attention mask is not used
         output_states = ()
 
@@ -649,10 +651,10 @@ class DownBlock3D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
         num_frames: int = 1,
-    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
 
         for resnet, temp_conv in zip(self.resnets, self.temp_convs):
@@ -767,15 +769,15 @@ class CrossAttnUpBlock3D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         num_frames: int = 1,
         cross_attention_kwargs: Dict[str, Any] = None,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -889,12 +891,12 @@ class UpBlock3D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         num_frames: int = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -1006,12 +1008,12 @@ class DownBlockMotion(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
         num_frames: int = 1,
         *args,
         **kwargs,
-    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
```
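The deprecation message above points at the supported path: pass the LoRA `scale` through `cross_attention_kwargs` on the pipeline call instead of handing it to a block's `forward`. A sketch under assumptions (the model ids and LoRA path are illustrative, and a LoRA must already be loaded for `scale` to have an effect):

```python
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter)
pipe.load_lora_weights("path/to/style_lora.safetensors")  # hypothetical LoRA file

frames = pipe(
    "a rocket lifting off from a forest clearing",
    num_frames=16,
    cross_attention_kwargs={"scale": 0.7},  # LoRA scale, applied inside the attention processors
).frames[0]
```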
diffusers/models/unets/unet_3d_blocks.py (continued)

```diff
@@ -1172,18 +1174,18 @@ class CrossAttnDownBlockMotion(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         num_frames: int = 1,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        additional_residuals: Optional[torch.FloatTensor] = None,
+        additional_residuals: Optional[torch.Tensor] = None,
     ):
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         output_states = ()
 
@@ -1355,19 +1357,19 @@ class CrossAttnUpBlockMotion(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         num_frames: int = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         is_freeu_enabled = (
             getattr(self, "s1", None)
@@ -1516,14 +1518,14 @@ class UpBlockMotion(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size=None,
         num_frames: int = 1,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1697,17 +1699,17 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         num_frames: int = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         hidden_states = self.resnets[0](hidden_states, temb)
 
```
diffusers/models/unets/unet_3d_blocks.py (continued)

```diff
@@ -1809,8 +1811,8 @@ class MidBlockTemporalDecoder(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        image_only_indicator: torch.FloatTensor,
+        hidden_states: torch.Tensor,
+        image_only_indicator: torch.Tensor,
     ):
         hidden_states = self.resnets[0](
             hidden_states,
@@ -1860,9 +1862,9 @@ class UpBlockTemporalDecoder(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        image_only_indicator: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        image_only_indicator: torch.Tensor,
+    ) -> torch.Tensor:
         for resnet in self.resnets:
             hidden_states = resnet(
                 hidden_states,
@@ -1933,11 +1935,11 @@ class UNetMidBlockSpatioTemporal(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         hidden_states = self.resnets[0](
             hidden_states,
             temb,
@@ -2029,10 +2031,10 @@ class DownBlockSpatioTemporal(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
         for resnet in self.resnets:
             if self.training and self.gradient_checkpointing:
@@ -2139,11 +2141,11 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
 
         blocks = list(zip(self.resnets, self.attentions))
@@ -2238,11 +2240,11 @@ class UpBlockSpatioTemporal(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
@@ -2347,12 +2349,12 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
```
diffusers/models/unets/unet_3d_condition.py

```diff
@@ -55,11 +55,11 @@ class UNet3DConditionOutput(BaseOutput):
     The output of [`UNet3DConditionModel`].
 
     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
             The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
     """
 
-    sample: torch.FloatTensor
+    sample: torch.Tensor
 
 
 class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
@@ -91,6 +91,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
         attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
         num_attention_heads (`int`, *optional*): The number of attention heads.
+        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+            The dimension of `cond_proj` layer in the timestep embedding.
     """
 
     _supports_gradient_checkpointing = False
@@ -123,6 +125,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         cross_attention_dim: int = 1024,
         attention_head_dim: Union[int, Tuple[int]] = 64,
         num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+        time_cond_proj_dim: Optional[int] = None,
     ):
         super().__init__()
 
@@ -174,6 +177,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             timestep_input_dim,
             time_embed_dim,
             act_fn=act_fn,
+            cond_proj_dim=time_cond_proj_dim,
         )
 
         self.transformer_in = TransformerTemporalModel(
```
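`time_cond_proj_dim` wires an extra conditioning projection into the timestep embedding, the mechanism guidance-distilled (LCM-style) UNets use for their guidance embedding. A minimal sketch of the new config knob; the value 256 is illustrative and the call builds the large default (untrained) architecture:

```python
from diffusers import UNet3DConditionModel

unet = UNet3DConditionModel(time_cond_proj_dim=256)
print(unet.config.time_cond_proj_dim)  # 256
print(unet.time_embedding.cond_proj)   # Linear layer that consumes the 256-dim conditioning input
```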
diffusers/models/unets/unet_3d_condition.py (continued)

```diff
@@ -507,8 +511,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
 
         <Tip warning={true}>
 
@@ -556,7 +560,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         class_labels: Optional[torch.Tensor] = None,
@@ -566,15 +570,15 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         mid_block_additional_residual: Optional[torch.Tensor] = None,
         return_dict: bool = True,
-    ) -> Union[UNet3DConditionOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]:
         r"""
         The [`UNet3DConditionModel`] forward method.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor with the following shape `(batch, num_channels, num_frames, height, width`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
-            encoder_hidden_states (`torch.FloatTensor`):
+            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                 Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
```
diffusers/models/unets/unet_i2vgen_xl.py

```diff
@@ -81,8 +81,8 @@ class I2VGenXLTransformerTemporalEncoder(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
         norm_hidden_states = self.norm1(hidden_states)
         attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
         hidden_states = attn_output + hidden_states
@@ -99,8 +99,8 @@ class I2VGenXLTransformerTemporalEncoder(nn.Module):
 
 class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     r"""
-    I2VGenXL UNet. It is a conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep
-    and returns a sample-shaped output.
+    I2VGenXL UNet. It is a conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and
+    returns a sample-shaped output.
 
     This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
     for all models (such as downloading or saving).
@@ -477,8 +477,8 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
 
         <Tip warning={true}>
 
@@ -514,7 +514,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         fps: torch.Tensor,
         image_latents: torch.Tensor,
@@ -523,18 +523,19 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         timestep_cond: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[UNet3DConditionOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]:
         r"""
         The [`I2VGenXLUNet`] forward method.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
             fps (`torch.Tensor`): Frames per second for the video being generated. Used as a "micro-condition".
-            image_latents (`torch.FloatTensor`): Image encodings from the VAE.
-            image_embeddings (`torch.FloatTensor`):
-                Projection embeddings of the conditioning image computed with a vision encoder.
+            image_latents (`torch.Tensor`): Image encodings from the VAE.
+            image_embeddings (`torch.Tensor`):
+                Projection embeddings of the conditioning image computed with a vision encoder.
+            encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
```