diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- diffusers/__init__.py +26 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +33 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +8 -0
- diffusers/models/activations.py +23 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +475 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +363 -32
- diffusers/models/model_loading_utils.py +177 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +175 -99
- diffusers/models/normalization.py +2 -1
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/__init__.py +3 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
- diffusers/models/transformers/pixart_transformer_2d.py +336 -0
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +292 -184
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +27 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +269 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
- diffusers-0.28.1.dist-info/RECORD +419 -0
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
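Several of the paths above are entirely new in 0.28.x, for example the Marigold, Hunyuan-DiT, ControlNet-XS and PixArt-Sigma pipelines, the `PixArtTransformer2DModel`, and the `callbacks.py` and `video_processor.py` helpers. As a rough orientation only, the sketch below shows how one of the newly added pipelines might be loaded after upgrading; the checkpoint id, dtype and device are illustrative assumptions and are not part of this diff.

import torch
from diffusers import PixArtSigmaPipeline  # exported for the first time in the 0.28.x release line

# Assumed checkpoint id for illustration; any PixArt-Sigma repo in diffusers layout should work.
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

image = pipe(prompt="a watercolor painting of a lighthouse at dusk", num_inference_steps=20).images[0]
image.save("lighthouse.png")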
@@ -0,0 +1,336 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional
+
+import torch
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import is_torch_version, logging
+from ..attention import BasicTransformerBlock
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
+    r"""
+    A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
+    https://arxiv.org/abs/2403.04692).
+
+    Parameters:
+        num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
+        in_channels (int, defaults to 4): The number of channels in the input.
+        out_channels (int, optional):
+            The number of channels in the output. Specify this parameter if the output channel number differs from the
+            input.
+        num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
+        dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
+        norm_num_groups (int, optional, defaults to 32):
+            Number of groups for group normalization within Transformer blocks.
+        cross_attention_dim (int, optional):
+            The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension.
+        attention_bias (bool, optional, defaults to True):
+            Configure if the Transformer blocks' attention should contain a bias parameter.
+        sample_size (int, defaults to 128):
+            The width of the latent images. This parameter is fixed during training.
+        patch_size (int, defaults to 2):
+            Size of the patches the model processes, relevant for architectures working on non-sequential data.
+        activation_fn (str, optional, defaults to "gelu-approximate"):
+            Activation function to use in feed-forward networks within Transformer blocks.
+        num_embeds_ada_norm (int, optional, defaults to 1000):
+            Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
+            inference.
+        upcast_attention (bool, optional, defaults to False):
+            If true, upcasts the attention mechanism dimensions for potentially improved performance.
+        norm_type (str, optional, defaults to "ada_norm_zero"):
+            Specifies the type of normalization used, can be 'ada_norm_zero'.
+        norm_elementwise_affine (bool, optional, defaults to False):
+            If true, enables element-wise affine parameters in the normalization layers.
+        norm_eps (float, optional, defaults to 1e-6):
+            A small constant added to the denominator in normalization layers to prevent division by zero.
+        interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings.
+        use_additional_conditions (bool, optional): If we're using additional conditions as inputs.
+        attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.
+        caption_channels (int, optional, defaults to None):
+            Number of channels to use for projecting the caption embeddings.
+        use_linear_projection (bool, optional, defaults to False):
+            Deprecated argument. Will be removed in a future version.
+        num_vector_embeds (bool, optional, defaults to False):
+            Deprecated argument. Will be removed in a future version.
+    """
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 72,
+        in_channels: int = 4,
+        out_channels: Optional[int] = 8,
+        num_layers: int = 28,
+        dropout: float = 0.0,
+        norm_num_groups: int = 32,
+        cross_attention_dim: Optional[int] = 1152,
+        attention_bias: bool = True,
+        sample_size: int = 128,
+        patch_size: int = 2,
+        activation_fn: str = "gelu-approximate",
+        num_embeds_ada_norm: Optional[int] = 1000,
+        upcast_attention: bool = False,
+        norm_type: str = "ada_norm_single",
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-6,
+        interpolation_scale: Optional[int] = None,
+        use_additional_conditions: Optional[bool] = None,
+        caption_channels: Optional[int] = None,
+        attention_type: Optional[str] = "default",
+    ):
+        super().__init__()
+
+        # Validate inputs.
+        if norm_type != "ada_norm_single":
+            raise NotImplementedError(
+                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
+            )
+        elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None:
+            raise ValueError(
+                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
+            )
+
+        # Set some common variables used across the board.
+        self.attention_head_dim = attention_head_dim
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.out_channels = in_channels if out_channels is None else out_channels
+        if use_additional_conditions is None:
+            if sample_size == 128:
+                use_additional_conditions = True
+            else:
+                use_additional_conditions = False
+        self.use_additional_conditions = use_additional_conditions
+
+        self.gradient_checkpointing = False
+
+        # 2. Initialize the position embedding and transformer blocks.
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+
+        interpolation_scale = (
+            self.config.interpolation_scale
+            if self.config.interpolation_scale is not None
+            else max(self.config.sample_size // 64, 1)
+        )
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.config.in_channels,
+            embed_dim=self.inner_dim,
+            interpolation_scale=interpolation_scale,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
+
+        # 3. Output blocks.
+        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+        self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
+
+        self.adaln_single = AdaLayerNormSingle(
+            self.inner_dim, use_additional_conditions=self.use_additional_conditions
+        )
+        self.caption_projection = None
+        if self.config.caption_channels is not None:
+            self.caption_projection = PixArtAlphaTextProjection(
+                in_features=self.config.caption_channels, hidden_size=self.inner_dim
+            )
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        added_cond_kwargs: Dict[str, torch.Tensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ):
+        """
+        The [`PixArtTransformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                Input `hidden_states`.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                self-attention.
+            timestep (`torch.LongTensor`, *optional*):
+                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+            added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
+            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            attention_mask ( `torch.Tensor`, *optional*):
+                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+                negative values to the attention scores corresponding to "discard" tokens.
+            encoder_attention_mask ( `torch.Tensor`, *optional*):
+                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+                    * Mask `(batch, sequence_length)` True = keep, False = discard.
+                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+                above. This bias will be added to the cross-attention scores.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+        if self.use_additional_conditions and added_cond_kwargs is None:
+            raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None and attention_mask.ndim == 2:
+            # assume that mask is expressed as:
+            #   (1 = keep, 0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0, discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 1. Input
+        batch_size = hidden_states.shape[0]
+        height, width = (
+            hidden_states.shape[-2] // self.config.patch_size,
+            hidden_states.shape[-1] // self.config.patch_size,
+        )
+        hidden_states = self.pos_embed(hidden_states)
+
+        timestep, embedded_timestep = self.adaln_single(
+            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+        )
+
+        if self.caption_projection is not None:
+            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        # 2. Blocks
+        for block in self.transformer_blocks:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    cross_attention_kwargs,
+                    None,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    timestep=timestep,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    class_labels=None,
+                )
+
+        # 3. Output
+        shift, scale = (
+            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
+        ).chunk(2, dim=1)
+        hidden_states = self.norm_out(hidden_states)
+        # Modulation
+        hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = hidden_states.squeeze(1)
+
+        # unpatchify
+        hidden_states = hidden_states.reshape(
+            shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
+        )
+        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
+        )
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
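The hunk above is the new `diffusers/models/transformers/pixart_transformer_2d.py`: the forward pass patch-embeds the latent, derives the timestep modulation through `AdaLayerNormSingle`, runs the stack of `BasicTransformerBlock`s, applies the learned scale/shift table, and unpatchifies back to image space. A minimal smoke-test sketch of that path is shown below; the tiny hyperparameters and tensor shapes are made up for illustration and do not correspond to any released PixArt checkpoint.

import torch
from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel

# Deliberately tiny configuration, just to exercise the forward-pass shapes.
model = PixArtTransformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,      # inner_dim = 2 * 8 = 16
    in_channels=4,
    out_channels=8,
    num_layers=1,
    cross_attention_dim=16,    # equals inner_dim, since captions are projected to inner_dim
    sample_size=32,            # != 128, so no extra resolution/aspect-ratio conditions are required
    caption_channels=16,
)

latents = torch.randn(1, 4, 32, 32)   # (batch, in_channels, height, width)
captions = torch.randn(1, 7, 16)      # (batch, sequence length, caption_channels)
timestep = torch.tensor([999])

with torch.no_grad():
    out = model(
        latents,
        encoder_hidden_states=captions,
        timestep=timestep,
        # The PixArt pipelines pass this dict even when the extra conditions are unused,
        # with None entries, as done here.
        added_cond_kwargs={"resolution": None, "aspect_ratio": None},
    ).sample

print(out.shape)  # (1, 8, 32, 32): out_channels, original spatial size restored by unpatchify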
@@ -26,11 +26,11 @@ class PriorTransformerOutput(BaseOutput):
     The output of [`PriorTransformer`].
 
     Args:
-        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+        predicted_image_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
             The predicted CLIP image embedding conditioned on the CLIP text embedding input.
     """
 
-    predicted_image_embedding: torch.FloatTensor
+    predicted_image_embedding: torch.Tensor
 
 
 class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
@@ -246,8 +246,8 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
         self,
         hidden_states,
         timestep: Union[torch.Tensor, float, int],
-        proj_embedding: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        proj_embedding: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.BoolTensor] = None,
         return_dict: bool = True,
     ):
@@ -255,13 +255,13 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
         The [`PriorTransformer`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+            hidden_states (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
                 The currently predicted image embeddings.
             timestep (`torch.LongTensor`):
                 Current denoising step.
-            proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+            proj_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
                 Projected embedding vector the denoising process is conditioned on.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
                 Hidden states of the text embeddings the denoising process is conditioned on.
             attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                 Text mask for the text embeddings.
@@ -86,7 +86,7 @@ class T5FilmDecoder(ModelMixin, ConfigMixin):
         self.post_dropout = nn.Dropout(p=dropout_rate)
         self.spec_out = nn.Linear(d_model, input_dims, bias=False)
 
-    def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
+    def encoder_decoder_mask(self, query_input: torch.Tensor, key_input: torch.Tensor) -> torch.Tensor:
         mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
         return mask.unsqueeze(-3)
 
@@ -195,13 +195,13 @@ class DecoderLayer(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        conditioning_emb: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        conditioning_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         encoder_decoder_position_bias=None,
-    ) -> Tuple[torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor]:
         hidden_states = self.layer[0](
             hidden_states,
             conditioning_emb=conditioning_emb,
@@ -249,10 +249,10 @@ class T5LayerSelfAttentionCond(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        conditioning_emb: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        conditioning_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         # pre_self_attention_layer_norm
         normed_hidden_states = self.layer_norm(hidden_states)
 
@@ -292,10 +292,10 @@ class T5LayerCrossAttention(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        key_value_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.attention(
             normed_hidden_states,
@@ -328,9 +328,7 @@ class T5LayerFFCond(nn.Module):
         self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
         self.dropout = nn.Dropout(dropout_rate)
 
-    def forward(
-        self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
-    ) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, conditioning_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
         forwarded_states = self.layer_norm(hidden_states)
         if conditioning_emb is not None:
             forwarded_states = self.film(forwarded_states, conditioning_emb)
@@ -361,7 +359,7 @@ class T5DenseGatedActDense(nn.Module):
         self.dropout = nn.Dropout(dropout_rate)
         self.act = NewGELUActivation()
 
-    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_gelu = self.act(self.wi_0(hidden_states))
         hidden_linear = self.wi_1(hidden_states)
         hidden_states = hidden_gelu * hidden_linear
@@ -390,7 +388,7 @@ class T5LayerNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
         # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
         # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
@@ -431,7 +429,7 @@ class T5FiLMLayer(nn.Module):
         super().__init__()
         self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
 
-    def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor, conditioning_emb: torch.Tensor) -> torch.Tensor:
         emb = self.scale_bias(conditioning_emb)
         scale, shift = torch.chunk(emb, 2, -1)
         x = x * (1 + scale) + shift