diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
- diffusers-0.28.0.dist-info/RECORD +414 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/models/unets/unet_motion_model.py
CHANGED
@@ -15,9 +15,10 @@ from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils.checkpoint
 
-from ...configuration_utils import ConfigMixin, register_to_config
+from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
 from ...loaders import UNet2DConditionLoadersMixin
 from ...utils import logging
 from ..attention_processor import (
@@ -27,6 +28,9 @@ from ..attention_processor import (
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    AttnProcessor2_0,
+    IPAdapterAttnProcessor,
+    IPAdapterAttnProcessor2_0,
 )
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
@@ -211,6 +215,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         norm_num_groups: int = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
         use_linear_projection: bool = False,
         num_attention_heads: Union[int, Tuple[int, ...]] = 8,
         motion_max_seq_length: int = 32,
@@ -218,6 +224,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         use_motion_mid_block: int = True,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
+        projection_class_embeddings_input_dim: Optional[int] = None,
         time_cond_proj_dim: Optional[int] = None,
     ):
         super().__init__()
@@ -240,6 +249,21 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
             )
 
+        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+            )
+
+        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+            for layer_number_per_block in transformer_layers_per_block:
+                if isinstance(layer_number_per_block, list):
+                    raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
         # input
         conv_in_kernel = 3
         conv_out_kernel = 3
@@ -260,6 +284,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if encoder_hid_dim_type is None:
             self.encoder_hid_proj = None
 
+        if addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, True, 0)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
         # class embedding
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
@@ -267,6 +295,15 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if isinstance(num_attention_heads, int):
             num_attention_heads = (num_attention_heads,) * len(down_block_types)
 
+        if isinstance(cross_attention_dim, int):
+            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+        if isinstance(layers_per_block, int):
+            layers_per_block = [layers_per_block] * len(down_block_types)
+
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
         # down
         output_channel = block_out_channels[0]
         for i, down_block_type in enumerate(down_block_types):
@@ -276,7 +313,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
 
             down_block = get_down_block(
                 down_block_type,
-                num_layers=layers_per_block,
+                num_layers=layers_per_block[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 temb_channels=time_embed_dim,
@@ -284,13 +321,14 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[i],
                 num_attention_heads=num_attention_heads[i],
                 downsample_padding=downsample_padding,
                 use_linear_projection=use_linear_projection,
                 dual_cross_attention=False,
                 temporal_num_attention_heads=motion_num_attention_heads,
                 temporal_max_seq_length=motion_max_seq_length,
+                transformer_layers_per_block=transformer_layers_per_block[i],
             )
             self.down_blocks.append(down_block)
 
@@ -302,13 +340,14 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 output_scale_factor=mid_block_scale_factor,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[-1],
                 num_attention_heads=num_attention_heads[-1],
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=False,
                 use_linear_projection=use_linear_projection,
                 temporal_num_attention_heads=motion_num_attention_heads,
                 temporal_max_seq_length=motion_max_seq_length,
+                transformer_layers_per_block=transformer_layers_per_block[-1],
             )
 
         else:
@@ -318,11 +357,12 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 output_scale_factor=mid_block_scale_factor,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[-1],
                 num_attention_heads=num_attention_heads[-1],
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=False,
                 use_linear_projection=use_linear_projection,
+                transformer_layers_per_block=transformer_layers_per_block[-1],
             )
 
         # count how many layers upsample the images
@@ -331,6 +371,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         reversed_num_attention_heads = list(reversed(num_attention_heads))
+        reversed_layers_per_block = list(reversed(layers_per_block))
+        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
 
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
@@ -349,7 +392,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
 
             up_block = get_up_block(
                 up_block_type,
-                num_layers=layers_per_block + 1,
+                num_layers=reversed_layers_per_block[i] + 1,
                 in_channels=input_channel,
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
@@ -358,13 +401,14 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=reversed_cross_attention_dim[i],
                 num_attention_heads=reversed_num_attention_heads[i],
                 dual_cross_attention=False,
                 resolution_idx=i,
                 use_linear_projection=use_linear_projection,
                 temporal_num_attention_heads=motion_num_attention_heads,
                 temporal_max_seq_length=motion_max_seq_length,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -393,8 +437,11 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     ):
         has_motion_adapter = motion_adapter is not None
 
+        if has_motion_adapter:
+            motion_adapter.to(device=unet.device)
+
         # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
-        config = unet.config
+        config = dict(unet.config)
         config["_class_name"] = cls.__name__
 
         down_blocks = []
@@ -427,6 +474,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if not config.get("num_attention_heads"):
             config["num_attention_heads"] = config["attention_head_dim"]
 
+        config = FrozenDict(config)
         model = cls.from_config(config)
 
         if not load_weights:
@@ -446,6 +494,36 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         model.time_proj.load_state_dict(unet.time_proj.state_dict())
         model.time_embedding.load_state_dict(unet.time_embedding.state_dict())
 
+        if any(
+            isinstance(proc, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
+            for proc in unet.attn_processors.values()
+        ):
+            attn_procs = {}
+            for name, processor in unet.attn_processors.items():
+                if name.endswith("attn1.processor"):
+                    attn_processor_class = (
+                        AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+                    )
+                    attn_procs[name] = attn_processor_class()
+                else:
+                    attn_processor_class = (
+                        IPAdapterAttnProcessor2_0
+                        if hasattr(F, "scaled_dot_product_attention")
+                        else IPAdapterAttnProcessor
+                    )
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=processor.hidden_size,
+                        cross_attention_dim=processor.cross_attention_dim,
+                        scale=processor.scale,
+                        num_tokens=processor.num_tokens,
+                    )
+            for name, processor in model.attn_processors.items():
+                if name not in attn_procs:
+                    attn_procs[name] = processor.__class__()
+            model.set_attn_processor(attn_procs)
+            model.config.encoder_hid_dim_type = "ip_image_proj"
+            model.encoder_hid_proj = unet.encoder_hid_proj
+
         for i, down_block in enumerate(unet.down_blocks):
             model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict())
             if hasattr(model.down_blocks[i], "attentions"):
@@ -705,8 +783,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
         """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
 
         <Tip warning={true}>
 
@@ -742,7 +820,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         timestep_cond: Optional[torch.Tensor] = None,
@@ -757,10 +835,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         The [`UNetMotionModel`] forward method.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
-            encoder_hidden_states (`torch.FloatTensor`):
+            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
                 Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
@@ -831,6 +909,28 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         t_emb = t_emb.to(dtype=self.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
+        aug_emb = None
+
+        if self.config.addition_embed_type == "text_time":
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
+            aug_emb = self.add_embedding(add_embeds)
+
+        emb = emb if aug_emb is None else emb + aug_emb
         emb = emb.repeat_interleave(repeats=num_frames, dim=0)
         encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
 
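Note: the `addition_embed_type="text_time"` branch added to `UNetMotionModel` above (SDXL-style conditioning for AnimateDiff) expects `added_cond_kwargs` to carry `text_embeds` and `time_ids` at call time; they are projected through `add_time_proj`/`add_embedding` and added to the time embedding. A minimal sketch of what a caller would pass follows; the shapes and config values are illustrative assumptions, not taken from this diff.

import torch

# Illustrative only: `model` is assumed to be a UNetMotionModel whose config sets
# addition_embed_type="text_time", addition_time_embed_dim=256 and
# projection_class_embeddings_input_dim=2816 (SDXL-like values).
batch, num_frames = 1, 16
sample = torch.randn(batch, num_frames, 4, 64, 64)      # noisy video latents
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(batch, 77, 2048)    # text-encoder hidden states

added_cond_kwargs = {
    "text_embeds": torch.randn(batch, 1280),  # pooled text embedding
    "time_ids": torch.randn(batch, 6),        # original/crop/target size ids
}

# noise_pred = model(
#     sample,
#     timestep,
#     encoder_hidden_states=encoder_hidden_states,
#     added_cond_kwargs=added_cond_kwargs,
# ).sample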
diffusers/models/unets/unet_spatio_temporal_condition.py
CHANGED
@@ -22,17 +22,17 @@ class UNetSpatioTemporalConditionOutput(BaseOutput):
     The output of [`UNetSpatioTemporalConditionModel`].
 
     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
+        sample (`torch.Tensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
             The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
     """
 
-    sample: torch.FloatTensor = None
+    sample: torch.Tensor = None
 
 
 class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     r"""
-    A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample
-    shaped output.
+    A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and
+    returns a sample shaped output.
 
     This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
     for all models (such as downloading or saving).
@@ -57,7 +57,8 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
             The dimension of the cross attention features.
         transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
             The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
-            [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
+            [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
+            [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
             [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
         num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
             The number of attention heads.
@@ -355,7 +356,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         added_time_ids: torch.Tensor,
@@ -365,21 +366,21 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
         The [`UNetSpatioTemporalConditionModel`] forward method.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
-            encoder_hidden_states (`torch.FloatTensor`):
+            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
-            added_time_ids: (`torch.FloatTensor`):
+            added_time_ids: (`torch.Tensor`):
                 The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
                 embeddings and added to the time embeddings.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
-                tuple.
+                Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead
+                of a plain tuple.
         Returns:
             [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
-                If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
-                a `tuple` is returned where the first element is the sample tensor.
+                If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
+                returned, otherwise a `tuple` is returned where the first element is the sample tensor.
         """
         # 1. time
         timesteps = timestep
diffusers/models/unets/unet_stable_cascade.py
CHANGED
@@ -21,7 +21,7 @@ import torch
 import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalUNetMixin
+from ...loaders import FromOriginalModelMixin
 from ...utils import BaseOutput
 from ..attention_processor import Attention
 from ..modeling_utils import ModelMixin
@@ -41,11 +41,11 @@ class SDCascadeLayerNorm(nn.LayerNorm):
 class SDCascadeTimestepBlock(nn.Module):
     def __init__(self, c, c_timestep, conds=[]):
         super().__init__()
-        linear_cls = nn.Linear
-        self.mapper = linear_cls(c_timestep, c * 2)
+
+        self.mapper = nn.Linear(c_timestep, c * 2)
         self.conds = conds
         for cname in conds:
-            setattr(self, f"mapper_{cname}", linear_cls(c_timestep, c * 2))
+            setattr(self, f"mapper_{cname}", nn.Linear(c_timestep, c * 2))
 
     def forward(self, x, t):
         t = t.chunk(len(self.conds) + 1, dim=1)
@@ -94,12 +94,11 @@ class GlobalResponseNorm(nn.Module):
 class SDCascadeAttnBlock(nn.Module):
     def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
         super().__init__()
-        linear_cls = nn.Linear
 
         self.self_attn = self_attn
         self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
-        self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
+        self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
 
     def forward(self, x, kv):
         kv = self.kv_mapper(kv)
@@ -132,10 +131,10 @@ class UpDownBlock2d(nn.Module):
 
 @dataclass
 class StableCascadeUNetOutput(BaseOutput):
-    sample: torch.FloatTensor = None
+    sample: torch.Tensor = None
 
 
-class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
+class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     _supports_gradient_checkpointing = True
 
     @register_to_config
@@ -187,7 +186,8 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
         block_out_channels (Tuple[int], defaults to (2048, 2048)):
             Tuple of output channels for each block.
         num_attention_heads (Tuple[int], defaults to (32, 32)):
-            Number of attention heads in each attention block. Set to -1 to if block types in a layer do not have attention.
+            Number of attention heads in each attention block. Set to -1 to if block types in a layer do not have
+            attention.
         down_num_layers_per_block (Tuple[int], defaults to [8, 24]):
             Number of layers in each down block.
         up_num_layers_per_block (Tuple[int], defaults to [24, 8]):
@@ -198,10 +198,9 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
            Number of 1x1 Convolutional layers to repeat in each up block.
         block_types_per_layer (Tuple[Tuple[str]], optional,
            defaults to (
-                ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
-                ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock")
-            ):
-            Block types used in each layer of the up/down blocks.
+                ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), ("SDCascadeResBlock",
+                "SDCascadeTimestepBlock", "SDCascadeAttnBlock")
+            ): Block types used in each layer of the up/down blocks.
         clip_text_in_channels (`int`, *optional*, defaults to `None`):
             Number of input channels for CLIP based text conditioning.
         clip_text_pooled_in_channels (`int`, *optional*, defaults to 1280):
@@ -521,9 +520,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
                     if isinstance(block, SDCascadeResBlock):
                         skip = level_outputs[i] if k == 0 and i > 0 else None
                         if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
+                            orig_type = x.dtype
                             x = torch.nn.functional.interpolate(
                                 x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
                             )
+                            x = x.to(orig_type)
                         x = torch.utils.checkpoint.checkpoint(
                             create_custom_forward(block), x, skip, use_reentrant=False
                         )
@@ -547,9 +548,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
                     if isinstance(block, SDCascadeResBlock):
                         skip = level_outputs[i] if k == 0 and i > 0 else None
                         if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
+                            orig_type = x.dtype
                             x = torch.nn.functional.interpolate(
                                 x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
                            )
+                            x = x.to(orig_type)
                         x = block(x, skip)
                     elif isinstance(block, SDCascadeAttnBlock):
                         x = block(x, clip)
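Aside from the mixin rename (`FromOriginalUNetMixin` → `FromOriginalModelMixin`) and dropping the `linear_cls` indirection, the behavioral fix in the last two `StableCascadeUNet` hunks is that skip connections are still resized in float32, but the result is now cast back to the activation's original dtype, so half-precision runs are no longer silently upcast. A standalone sketch of that pattern (not the library code itself):

import torch
import torch.nn.functional as F

def resize_like(x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
    """Bilinear-resize `x` to `skip`'s spatial size without changing its dtype.

    Mirrors the pattern added in the diff: upcast to float32 for the
    interpolation, then cast back to the original dtype.
    """
    if x.shape[-2:] == skip.shape[-2:]:
        return x
    orig_dtype = x.dtype
    x = F.interpolate(x.float(), skip.shape[-2:], mode="bilinear", align_corners=True)
    return x.to(orig_dtype)

x = torch.randn(1, 16, 12, 12, dtype=torch.float16)
skip = torch.randn(1, 16, 24, 24, dtype=torch.float16)
assert resize_like(x, skip).dtype == torch.float16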
diffusers/models/upsampling.py
CHANGED
@@ -110,7 +110,6 @@ class Upsample2D(nn.Module):
         self.use_conv_transpose = use_conv_transpose
         self.name = name
         self.interpolate = interpolate
-        conv_cls = nn.Conv2d
 
         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -131,7 +130,7 @@ class Upsample2D(nn.Module):
         elif use_conv:
             if kernel_size is None:
                 kernel_size = 3
-            conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
+            conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
 
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":
@@ -139,9 +138,7 @@ class Upsample2D(nn.Module):
         else:
             self.Conv2d_0 = conv
 
-    def forward(
-        self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, *args, **kwargs
-    ) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -218,12 +215,12 @@ class FirUpsample2D(nn.Module):
 
     def _upsample_2d(
         self,
-        hidden_states: torch.FloatTensor,
-        weight: Optional[torch.FloatTensor] = None,
-        kernel: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+        kernel: Optional[torch.Tensor] = None,
         factor: int = 2,
         gain: float = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """Fused `upsample_2d()` followed by `Conv2d()`.
 
         Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
@@ -231,19 +228,19 @@ class FirUpsample2D(nn.Module):
         arbitrary order.
 
         Args:
-            hidden_states (`torch.FloatTensor`):
+            hidden_states (`torch.Tensor`):
                 Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-            weight (`torch.FloatTensor`, *optional*):
+            weight (`torch.Tensor`, *optional*):
                 Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                 performed by `inChannels = x.shape[0] // numGroups`.
-            kernel (`torch.FloatTensor`, *optional*):
+            kernel (`torch.Tensor`, *optional*):
                 FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
                 corresponds to nearest-neighbor upsampling.
             factor (`int`, *optional*): Integer upsampling factor (default: 2).
             gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
 
         Returns:
-            output (`torch.FloatTensor`):
+            output (`torch.Tensor`):
                 Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
                 datatype as `hidden_states`.
         """
@@ -311,7 +308,7 @@ class FirUpsample2D(nn.Module):
 
         return output
 
-    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.use_conv:
             height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
             height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -402,11 +399,11 @@ def upfirdn2d_native(
 
 
 def upsample_2d(
-    hidden_states: torch.FloatTensor,
-    kernel: Optional[torch.FloatTensor] = None,
+    hidden_states: torch.Tensor,
+    kernel: Optional[torch.Tensor] = None,
     factor: int = 2,
     gain: float = 1,
-) -> torch.FloatTensor:
+) -> torch.Tensor:
     r"""Upsample2D a batch of 2D images with the given filter.
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
     filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
@@ -414,9 +411,9 @@ def upsample_2d(
    a: multiple of the upsampling factor.
 
     Args:
-        hidden_states (`torch.FloatTensor`):
+        hidden_states (`torch.Tensor`):
            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        kernel (`torch.FloatTensor`, *optional*):
+        kernel (`torch.Tensor`, *optional*):
            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to nearest-neighbor upsampling.
        factor (`int`, *optional*, default to `2`):
@@ -425,7 +422,7 @@ def upsample_2d(
            Scaling factor for signal magnitude (default: 1.0).
 
    Returns:
-        output (`torch.FloatTensor`):
+        output (`torch.Tensor`):
            Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
diffusers/models/vq_model.py
CHANGED
@@ -30,11 +30,11 @@ class VQEncoderOutput(BaseOutput):
     Output of VQModel encoding method.
 
     Args:
-        latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+        latents (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
             The encoded output sample from the last layer of the model.
     """
 
-    latents: torch.FloatTensor
+    latents: torch.Tensor
 
 
 class VQModel(ModelMixin, ConfigMixin):
@@ -127,7 +127,7 @@ class VQModel(ModelMixin, ConfigMixin):
         )
 
     @apply_forward_hook
-    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
+    def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
         h = self.encoder(x)
         h = self.quant_conv(h)
 
@@ -138,31 +138,33 @@ class VQModel(ModelMixin, ConfigMixin):
 
     @apply_forward_hook
     def decode(
-        self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+        self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
+    ) -> Union[DecoderOutput, torch.Tensor]:
         # also go through quantization layer
         if not force_not_quantize:
-            quant, _, _ = self.quantize(h)
+            quant, commit_loss, _ = self.quantize(h)
         elif self.config.lookup_from_codebook:
             quant = self.quantize.get_codebook_entry(h, shape)
+            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
         else:
             quant = h
+            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
         quant2 = self.post_quant_conv(quant)
         dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
 
         if not return_dict:
-            return (dec,)
+            return dec, commit_loss
 
-        return DecoderOutput(sample=dec)
+        return DecoderOutput(sample=dec, commit_loss=commit_loss)
 
     def forward(
-        self, sample: torch.FloatTensor, return_dict: bool = True
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor, ...]]:
+        self, sample: torch.Tensor, return_dict: bool = True
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
         r"""
         The [`VQModel`] forward method.
 
         Args:
-            sample (`torch.FloatTensor`): Input sample.
+            sample (`torch.Tensor`): Input sample.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple.
 
@@ -173,9 +175,8 @@ class VQModel(ModelMixin, ConfigMixin):
         """
 
         h = self.encode(sample).latents
-        dec = self.decode(h).sample
+        dec = self.decode(h)
 
         if not return_dict:
-            return (dec,)
-
-        return DecoderOutput(sample=dec)
+            return dec.sample, dec.commit_loss
+        return dec
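The `VQModel` change is a small API addition: `decode` now returns the quantizer's commitment loss alongside the sample (zeros when quantization is skipped), carried in `DecoderOutput.commit_loss`, and `forward(..., return_dict=False)` returns `(sample, commit_loss)` instead of a one-element tuple. A hedged sketch of how calling code might consume it; the randomly initialised default model below is an illustrative assumption, real use would load pretrained weights:

import torch
from diffusers import VQModel

model = VQModel()                        # default config, random weights (illustrative)
sample = torch.randn(1, 3, 32, 32)

out = model(sample)                      # DecoderOutput with the new field
print(out.sample.shape, out.commit_loss)

recon, commit_loss = model(sample, return_dict=False)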