diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +26 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +33 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +8 -0
- diffusers/models/activations.py +23 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +475 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +363 -32
- diffusers/models/model_loading_utils.py +177 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +175 -99
- diffusers/models/normalization.py +2 -1
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/__init__.py +3 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
- diffusers/models/transformers/pixart_transformer_2d.py +336 -0
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +292 -184
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +27 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +269 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
- diffusers-0.28.1.dist-info/RECORD +419 -0
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,427 @@
|
|
1
|
+
# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
from typing import Optional
|
15
|
+
|
16
|
+
import torch
|
17
|
+
import torch.nn.functional as F
|
18
|
+
from torch import nn
|
19
|
+
|
20
|
+
from ...configuration_utils import ConfigMixin, register_to_config
|
21
|
+
from ...utils import logging
|
22
|
+
from ...utils.torch_utils import maybe_allow_in_graph
|
23
|
+
from ..attention import FeedForward
|
24
|
+
from ..attention_processor import Attention, HunyuanAttnProcessor2_0
|
25
|
+
from ..embeddings import (
|
26
|
+
HunyuanCombinedTimestepTextSizeStyleEmbedding,
|
27
|
+
PatchEmbed,
|
28
|
+
PixArtAlphaTextProjection,
|
29
|
+
)
|
30
|
+
from ..modeling_outputs import Transformer2DModelOutput
|
31
|
+
from ..modeling_utils import ModelMixin
|
32
|
+
from ..normalization import AdaLayerNormContinuous
|
33
|
+
|
34
|
+
|
35
|
+
# Module-level logger for this file; the pylint pragma suppresses the naming-convention warning for a lowercase
# module-level name.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
36
|
+
|
37
|
+
|
38
|
+
class FP32LayerNorm(nn.LayerNorm):
    """
    `nn.LayerNorm` variant that always normalizes in float32 and casts the result back to the input dtype.

    The input, and the affine parameters when present, are upcast to float32 before `F.layer_norm` is applied. The
    output therefore has the same dtype as `inputs` (e.g. fp16/bf16) while the statistics are computed in float32.
    Works both with and without elementwise affine parameters.
    """

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Normalize `inputs` over `self.normalized_shape` in float32.

        Args:
            inputs (`torch.Tensor`): Tensor whose trailing dimensions match `self.normalized_shape`.

        Returns:
            `torch.Tensor`: The normalized tensor, cast back to `inputs.dtype`.
        """
        origin_dtype = inputs.dtype
        # `weight` / `bias` are `None` when the layer was built with `elementwise_affine=False`, so they can only
        # be upcast when they actually exist (calling `.float()` on `None` would raise).
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        return F.layer_norm(inputs.float(), self.normalized_shape, weight, bias, self.eps).to(origin_dtype)
|
44
|
+
|
45
|
+
|
46
|
+
class AdaLayerNormShift(nn.Module):
    r"""
    Layer normalization followed by a learned, embedding-dependent additive shift.

    The conditioning embedding goes through SiLU (evaluated in float32) and a linear projection. The resulting shift
    gets a singleton dimension inserted at position 1 and is added to the normalized input.

    Parameters:
        embedding_dim (`int`):
            Size of the normalized feature dimension; also the input and output size of the shift projection.
        elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether the inner `FP32LayerNorm` has learnable elementwise affine parameters.
        eps (`float`, *optional*, defaults to 1e-6):
            Epsilon used by the inner `FP32LayerNorm`.
    """

    def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6):
        super().__init__()
        # Submodules are registered in the same order (silu, linear, norm) so state-dict keys are unchanged.
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim)
        self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        # SiLU is evaluated in float32, then cast back to the embedding's own dtype before the projection.
        activated = self.silu(emb.to(torch.float32)).to(emb.dtype)
        shift = self.linear(activated)
        # `[:, None]` inserts the broadcast dimension at index 1 (same as `unsqueeze(dim=1)`).
        return self.norm(x) + shift[:, None]
|
65
|
+
|
66
|
+
|
67
|
+
@maybe_allow_in_graph
class HunyuanDiTBlock(nn.Module):
    r"""
    Transformer block used in the Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Supports an optional
    long skip connection and QK normalization.

    The block runs, in order: an optional long skip connection, self-attention, cross-attention and a feed-forward
    layer. Each sub-layer has its own normalization and a residual connection.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        cross_attention_dim (`int`, *optional*, defaults to 1024):
            The size of the encoder_hidden_states vector for cross attention.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            A small constant added to the denominator in normalization layers to prevent division by zero.
        final_dropout (`bool`, *optional*, defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*):
            The size of the hidden layer in the feed-forward block. Defaults to `None`.
        ff_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the feed-forward block.
        skip (`bool`, *optional*, defaults to `False`):
            Whether to use a long skip connection. If `True`, `forward` must receive a `skip` tensor.
        qk_norm (`bool`, *optional*, defaults to `True`):
            Whether to use normalization in QK calculation. Defaults to `True`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        cross_attention_dim: int = 1024,
        dropout=0.0,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-6,
        final_dropout: bool = False,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        skip: bool = False,
        qk_norm: bool = True,
    ):
        super().__init__()

        # Define 3 blocks. Each block has its own normalization layer.
        # NOTE: when new version comes, check norm2 and norm3
        # 1. Self-Attn
        self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        # NOTE(review): both attention layers use a hard-coded eps=1e-6, independent of `norm_eps`.
        self.attn1 = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=True,
            processor=HunyuanAttnProcessor2_0(),
        )

        # 2. Cross-Attn
        self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=True,
            processor=HunyuanAttnProcessor2_0(),
        )
        # 3. Feed-forward
        self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.ff = FeedForward(
            dim,
            dropout=dropout,  ### 0.0
            activation_fn=activation_fn,  ### approx GeLU
            final_dropout=final_dropout,  ### 0.0
            inner_dim=ff_inner_dim,  ### int(dim * mlp_ratio)
            bias=ff_bias,
        )

        # 4. Skip Connection
        if skip:
            # Normalizes and projects the concatenation [hidden_states, skip] (2 * dim channels) back to dim.
            self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True)
            self.skip_linear = nn.Linear(2 * dim, dim)
        else:
            self.skip_linear = None

        # Feed-forward chunking is disabled by default; see `set_chunk_feed_forward`.
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        """
        Configure chunked evaluation of the feed-forward layer.

        When enabled, the feed-forward input is split along `dim` into pieces of at most `chunk_size` elements. Each
        piece is processed separately and the results are concatenated. This can reduce peak memory. It assumes
        `self.ff` acts independently on each position along `dim`, so do not chunk the channel dimension.

        Args:
            chunk_size (`int` or `None`): Chunk length along `dim`. `None` disables chunking.
            dim (`int`, defaults to 0): Dimension to split along.

        Raises:
            ValueError: If `chunk_size` is not `None` and not positive.
        """
        if chunk_size is not None and chunk_size <= 0:
            raise ValueError(f"`chunk_size` must be a positive integer or None, got {chunk_size}.")
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def _feed_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `self.ff`, chunked along `self._chunk_dim` when `set_chunk_feed_forward` enabled chunking."""
        if self._chunk_size is None:
            return self.ff(hidden_states)
        # `split` also handles a last chunk shorter than `_chunk_size`.
        chunks = hidden_states.split(self._chunk_size, dim=self._chunk_dim)
        return torch.cat([self.ff(chunk) for chunk in chunks], dim=self._chunk_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb=None,
        skip=None,
    ) -> torch.Tensor:
        """
        Run the block.

        Args:
            hidden_states (`torch.Tensor`): Input tokens; the last dimension must equal `dim`.
            encoder_hidden_states (`torch.Tensor`, *optional*): Context for cross-attention.
            temb (`torch.Tensor`, *optional*): Conditioning embedding consumed by `norm1` (`AdaLayerNormShift`).
            image_rotary_emb (*optional*): Rotary embeddings forwarded to both attention layers.
            skip (`torch.Tensor`, *optional*): Long-skip tensor, concatenated with `hidden_states` on the last
                dimension. Required when the block was built with `skip=True`.

        Returns:
            `torch.Tensor`: Updated hidden states.

        Raises:
            ValueError: If the block has a skip connection but `skip` is `None`.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Long Skip Connection
        if self.skip_linear is not None:
            if skip is None:
                raise ValueError(
                    "This HunyuanDiTBlock was built with `skip=True`, so a `skip` tensor must be passed to `forward`."
                )
            cat = torch.cat([hidden_states, skip], dim=-1)
            cat = self.skip_norm(cat)
            hidden_states = self.skip_linear(cat)

        # 1. Self-Attention
        norm_hidden_states = self.norm1(hidden_states, temb)
        attn_output = self.attn1(
            norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attn_output

        # 2. Cross-Attention
        hidden_states = hidden_states + self.attn2(
            self.norm2(hidden_states),
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # 3. FFN Layer ### TODO: switch norm2 and norm3 in the state dict
        mlp_inputs = self.norm3(hidden_states)
        hidden_states = hidden_states + self._feed_forward(mlp_inputs)

        return hidden_states
|
209
|
+
|
210
|
+
|
211
|
+
class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
212
|
+
"""
|
213
|
+
HunYuanDiT: Diffusion model with a Transformer backbone.
|
214
|
+
|
215
|
+
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
|
216
|
+
|
217
|
+
Parameters:
|
218
|
+
num_attention_heads (`int`, *optional*, defaults to 16):
|
219
|
+
The number of heads to use for multi-head attention.
|
220
|
+
attention_head_dim (`int`, *optional*, defaults to 88):
|
221
|
+
The number of channels in each head.
|
222
|
+
in_channels (`int`, *optional*):
|
223
|
+
The number of channels in the input and output (specify if the input is **continuous**).
|
224
|
+
patch_size (`int`, *optional*):
|
225
|
+
The size of the patch to use for the input.
|
226
|
+
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
227
|
+
Activation function to use in feed-forward.
|
228
|
+
sample_size (`int`, *optional*):
|
229
|
+
The width of the latent images. This is fixed during training since it is used to learn a number of
|
230
|
+
position embeddings.
|
231
|
+
dropout (`float`, *optional*, defaults to 0.0):
|
232
|
+
The dropout probability to use.
|
233
|
+
cross_attention_dim (`int`, *optional*):
|
234
|
+
The number of dimension in the clip text embedding.
|
235
|
+
hidden_size (`int`, *optional*):
|
236
|
+
The size of hidden layer in the conditioning embedding layers.
|
237
|
+
num_layers (`int`, *optional*, defaults to 1):
|
238
|
+
The number of layers of Transformer blocks to use.
|
239
|
+
mlp_ratio (`float`, *optional*, defaults to 4.0):
|
240
|
+
The ratio of the hidden layer size to the input size.
|
241
|
+
learn_sigma (`bool`, *optional*, defaults to `True`):
|
242
|
+
Whether to predict variance.
|
243
|
+
cross_attention_dim_t5 (`int`, *optional*):
|
244
|
+
The number dimensions in t5 text embedding.
|
245
|
+
pooled_projection_dim (`int`, *optional*):
|
246
|
+
The size of the pooled projection.
|
247
|
+
text_len (`int`, *optional*):
|
248
|
+
The length of the clip text embedding.
|
249
|
+
text_len_t5 (`int`, *optional*):
|
250
|
+
The length of the T5 text embedding.
|
251
|
+
"""
|
252
|
+
|
253
|
+
@register_to_config
|
254
|
+
def __init__(
|
255
|
+
self,
|
256
|
+
num_attention_heads: int = 16,
|
257
|
+
attention_head_dim: int = 88,
|
258
|
+
in_channels: Optional[int] = None,
|
259
|
+
patch_size: Optional[int] = None,
|
260
|
+
activation_fn: str = "gelu-approximate",
|
261
|
+
sample_size=32,
|
262
|
+
hidden_size=1152,
|
263
|
+
num_layers: int = 28,
|
264
|
+
mlp_ratio: float = 4.0,
|
265
|
+
learn_sigma: bool = True,
|
266
|
+
cross_attention_dim: int = 1024,
|
267
|
+
norm_type: str = "layer_norm",
|
268
|
+
cross_attention_dim_t5: int = 2048,
|
269
|
+
pooled_projection_dim: int = 1024,
|
270
|
+
text_len: int = 77,
|
271
|
+
text_len_t5: int = 256,
|
272
|
+
):
|
273
|
+
super().__init__()
|
274
|
+
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
275
|
+
self.num_heads = num_attention_heads
|
276
|
+
self.inner_dim = num_attention_heads * attention_head_dim
|
277
|
+
|
278
|
+
self.text_embedder = PixArtAlphaTextProjection(
|
279
|
+
in_features=cross_attention_dim_t5,
|
280
|
+
hidden_size=cross_attention_dim_t5 * 4,
|
281
|
+
out_features=cross_attention_dim,
|
282
|
+
act_fn="silu_fp32",
|
283
|
+
)
|
284
|
+
|
285
|
+
self.text_embedding_padding = nn.Parameter(
|
286
|
+
torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
|
287
|
+
)
|
288
|
+
|
289
|
+
self.pos_embed = PatchEmbed(
|
290
|
+
height=sample_size,
|
291
|
+
width=sample_size,
|
292
|
+
in_channels=in_channels,
|
293
|
+
embed_dim=hidden_size,
|
294
|
+
patch_size=patch_size,
|
295
|
+
pos_embed_type=None,
|
296
|
+
)
|
297
|
+
|
298
|
+
self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
|
299
|
+
hidden_size,
|
300
|
+
pooled_projection_dim=pooled_projection_dim,
|
301
|
+
seq_len=text_len_t5,
|
302
|
+
cross_attention_dim=cross_attention_dim_t5,
|
303
|
+
)
|
304
|
+
|
305
|
+
# HunyuanDiT Blocks
|
306
|
+
self.blocks = nn.ModuleList(
|
307
|
+
[
|
308
|
+
HunyuanDiTBlock(
|
309
|
+
dim=self.inner_dim,
|
310
|
+
num_attention_heads=self.config.num_attention_heads,
|
311
|
+
activation_fn=activation_fn,
|
312
|
+
ff_inner_dim=int(self.inner_dim * mlp_ratio),
|
313
|
+
cross_attention_dim=cross_attention_dim,
|
314
|
+
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
|
315
|
+
skip=layer > num_layers // 2,
|
316
|
+
)
|
317
|
+
for layer in range(num_layers)
|
318
|
+
]
|
319
|
+
)
|
320
|
+
|
321
|
+
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
322
|
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
323
|
+
|
324
|
+
def forward(
    self,
    hidden_states,
    timestep,
    encoder_hidden_states=None,
    text_embedding_mask=None,
    encoder_hidden_states_t5=None,
    text_embedding_mask_t5=None,
    image_meta_size=None,
    style=None,
    image_rotary_emb=None,
    return_dict=True,
):
    """
    The [`HunyuanDiT2DModel`] forward method.

    Args:
        hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
            The input tensor.
        timestep (`torch.LongTensor`):
            Used to indicate denoising step. Required: its `dtype` is read to select the embedding dtype.
        encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence len, embed dims)`):
            Conditional embeddings for cross attention layer. This is the output of `BertModel`. Despite the
            `None` default it is concatenated unconditionally and must be provided.
        text_embedding_mask: torch.Tensor
            An attention mask of shape `(batch, key_tokens)` applied to `encoder_hidden_states`. This is the
            output of `BertModel`. Masked-out positions are replaced by the learned padding embedding.
        encoder_hidden_states_t5 (`torch.Tensor` of shape `(batch size, sequence len, embed dims)`):
            Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder. Despite the
            `None` default it is used unconditionally and must be provided.
        text_embedding_mask_t5: torch.Tensor
            An attention mask of shape `(batch, key_tokens)` applied to `encoder_hidden_states_t5`. This is the
            output of T5 Text Encoder. Masked-out positions are replaced by the learned padding embedding.
        image_meta_size (torch.Tensor):
            Conditional embedding indicating the image sizes.
        style (torch.Tensor):
            Conditional embedding indicating the style.
        image_rotary_emb (`torch.Tensor`):
            The image rotary embeddings to apply on query and key tensors during attention calculation.
        return_dict (bool):
            Whether to return a [`Transformer2DModelOutput`] instead of a plain tuple.

    Returns:
        [`Transformer2DModelOutput`] whose `sample` has shape
        `(batch size, out_channels, height // patch_size * patch_size, width // patch_size * patch_size)` when
        `return_dict` is True, otherwise a one-element tuple containing that tensor.
    """
    height, width = hidden_states.shape[-2:]

    hidden_states = self.pos_embed(hidden_states)

    temb = self.time_extra_emb(
        timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
    )  # [B, D]

    # text projection: project every T5 token into the cross-attention width.
    # `reshape` (not `view`) so non-contiguous encoder outputs are accepted as well.
    batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
    encoder_hidden_states_t5 = self.text_embedder(
        encoder_hidden_states_t5.reshape(-1, encoder_hidden_states_t5.shape[-1])
    )
    encoder_hidden_states_t5 = encoder_hidden_states_t5.reshape(batch_size, sequence_length, -1)

    # Concatenate the BERT and T5 token sequences (and their masks) along the sequence axis.
    encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
    text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
    text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()

    # Masked-out text positions are replaced by the learned padding embedding.
    encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)

    # Long skip connections: every block with index > num_layers // 2 consumes (LIFO) the output of an
    # earlier block. That is (num_layers - 1) // 2 consumers, so exactly that many producers, the blocks
    # with index < (num_layers - 1) // 2, must push their outputs. This keeps the stack balanced for
    # both even and odd `num_layers`.
    num_layers = self.config.num_layers
    mid_layer = num_layers // 2
    num_skips = (num_layers - 1) // 2
    skips = []
    for layer, block in enumerate(self.blocks):
        if layer > mid_layer:
            skip = skips.pop()
            hidden_states = block(
                hidden_states,
                temb=temb,
                encoder_hidden_states=encoder_hidden_states,
                image_rotary_emb=image_rotary_emb,
                skip=skip,
            )  # (N, L, D)
        else:
            hidden_states = block(
                hidden_states,
                temb=temb,
                encoder_hidden_states=encoder_hidden_states,
                image_rotary_emb=image_rotary_emb,
            )  # (N, L, D)

        if layer < num_skips:
            skips.append(hidden_states)

    # final layer
    hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
    hidden_states = self.proj_out(hidden_states)
    # (N, L, patch_size ** 2 * out_channels)

    # unpatchify: (N, out_channels, H, W)
    patch_size = self.pos_embed.patch_size
    height = height // patch_size
    width = width // patch_size

    hidden_states = hidden_states.reshape(
        shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
    )
    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
    output = hidden_states.reshape(
        shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
    )
    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
|