diffusers-0.27.2-py3-none-any.whl → diffusers-0.28.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +26 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +33 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +8 -0
- diffusers/models/activations.py +23 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +475 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +363 -32
- diffusers/models/model_loading_utils.py +177 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +175 -99
- diffusers/models/normalization.py +2 -1
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/__init__.py +3 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
- diffusers/models/transformers/pixart_transformer_2d.py +336 -0
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +292 -184
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +27 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +269 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
- diffusers-0.28.1.dist-info/RECORD +419 -0
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
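Several of the added files are brand-new entry points in 0.28.x (for example `diffusers/video_processor.py`, `diffusers/callbacks.py`, `diffusers/models/controlnet_xs.py`, and the `marigold`, `hunyuandit`, `controlnet_xs`, and `pixart_sigma` pipeline packages). A minimal, hedged way to see which of the corresponding classes the installed wheel actually exposes at the top level; the class names below are inferred from the new file names, so treat any that fail the check as not part of the public API:

```python
import diffusers

print(diffusers.__version__)  # expect 0.28.1 after upgrading

# Candidate names inferred from the new modules listed above; hasattr() keeps
# this safe even if a guess is wrong or a class is not re-exported at top level.
for name in (
    "MarigoldDepthPipeline",
    "MarigoldNormalsPipeline",
    "HunyuanDiTPipeline",
    "PixArtSigmaPipeline",
    "AnimateDiffSDXLPipeline",
    "StableDiffusionControlNetXSPipeline",
    "ControlNetXSAdapter",
):
    print(f"{name}: {hasattr(diffusers, name)}")
```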
diffusers/models/transformers/transformer_2d.py

@@ -11,39 +11,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
 from typing import Any, Dict, Optional
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import BaseOutput, deprecate, is_torch_version, logging
+from ...configuration_utils import LegacyConfigMixin, register_to_config
+from ...utils import deprecate, is_torch_version, logging
 from ..attention import BasicTransformerBlock
 from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
-from ..modeling_utils import ModelMixin
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import LegacyModelMixin
 from ..normalization import AdaLayerNormSingle
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-@dataclass
-class Transformer2DModelOutput(BaseOutput):
-    """
-    The output of [`Transformer2DModel`].
-
-    Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
-            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
-            distributions for the unnoised latent pixels.
-    """
-
-    sample: torch.FloatTensor
+class Transformer2DModelOutput(Transformer2DModelOutput):
+    deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead."
+    deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
 
 
-class Transformer2DModel(ModelMixin, ConfigMixin):
+class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
     """
     A 2D Transformer model for image-like data.
 
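The first hunk is the headline API change for this file: `Transformer2DModelOutput` now lives in `diffusers.models.modeling_outputs`, the class kept at the old location only re-exports it behind a `deprecate()` call, and `Transformer2DModel` switches to the new `LegacyModelMixin`/`LegacyConfigMixin` bases. A minimal sketch of the import migration (the old path keeps working in 0.28.1 but emits the deprecation warning quoted above):

```python
# Deprecated in 0.28: importing the output class from the transformer module itself.
# from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput

# Preferred import path going forward, as stated in the deprecation message:
from diffusers.models.modeling_outputs import Transformer2DModelOutput

print(Transformer2DModelOutput.__module__)  # diffusers.models.modeling_outputs
```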
@@ -72,6 +63,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock"]
 
     @register_to_config
     def __init__(
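`_no_split_modules` is the hint that the Accelerate-backed `device_map` loading path reads to decide which submodules must not be sharded across devices; listing `BasicTransformerBlock` means each transformer block is placed on a single device as a unit. A trivial check, no checkpoint download required:

```python
from diffusers.models.transformers.transformer_2d import Transformer2DModel

# New class attribute added in 0.28.1.
print(Transformer2DModel._no_split_modules)  # ['BasicTransformerBlock']
```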
@@ -100,8 +92,11 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         attention_type: str = "default",
         caption_channels: int = None,
         interpolation_scale: float = None,
+        use_additional_conditions: Optional[bool] = None,
     ):
         super().__init__()
+
+        # Validate inputs.
         if patch_size is not None:
             if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
                 raise NotImplementedError(
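The new `use_additional_conditions` argument makes the PixArt-style resolution/aspect-ratio conditioning an explicit choice; when it is left as `None`, a later hunk shows the old behaviour being kept as the default (`norm_type == "ada_norm_single"` and `sample_size == 128`). A small sketch of constructing a toy patched transformer with the flag set explicitly; every numeric value here is illustrative, not a recommended configuration:

```python
from diffusers.models.transformers.transformer_2d import Transformer2DModel

tiny = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    num_layers=1,
    sample_size=32,
    patch_size=2,
    norm_type="ada_norm_single",
    use_additional_conditions=False,  # new keyword in 0.28; previously inferred from sample_size == 128
)
print(tiny.use_additional_conditions)  # False
```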
@@ -112,31 +107,12 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                     f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
                 )
 
-        self.use_linear_projection = use_linear_projection
-        self.num_attention_heads = num_attention_heads
-        self.attention_head_dim = attention_head_dim
-        inner_dim = num_attention_heads * attention_head_dim
-
-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear
-
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
         self.is_input_continuous = (in_channels is not None) and (patch_size is None)
         self.is_input_vectorized = num_vector_embeds is not None
         self.is_input_patches = in_channels is not None and patch_size is not None
 
-        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
-            deprecation_message = (
-                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
-                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
-                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
-                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
-                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
-            )
-            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
-            norm_type = "ada_norm"
-
         if self.is_input_continuous and self.is_input_vectorized:
             raise ValueError(
                 f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
@@ -153,104 +129,194 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                 f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
             )
 
-        …
+        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+            deprecation_message = (
+                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+                " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
+                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
+                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
+                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
+            )
+            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
+            norm_type = "ada_norm"
 
-        …
+        # Set some common variables used across the board.
+        self.use_linear_projection = use_linear_projection
+        self.interpolation_scale = interpolation_scale
+        self.caption_channels = caption_channels
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.gradient_checkpointing = False
+
+        if use_additional_conditions is None:
+            if norm_type == "ada_norm_single" and sample_size == 128:
+                use_additional_conditions = True
             else:
-            …
+                use_additional_conditions = False
+        self.use_additional_conditions = use_additional_conditions
+
+        # 2. Initialize the right blocks.
+        # These functions follow a common structure:
+        # a. Initialize the input blocks. b. Initialize the transformer blocks.
+        # c. Initialize the output blocks and other projection blocks when necessary.
+        if self.is_input_continuous:
+            self._init_continuous_input(norm_type=norm_type)
         elif self.is_input_vectorized:
-            …
+            self._init_vectorized_inputs(norm_type=norm_type)
+        elif self.is_input_patches:
+            self._init_patched_inputs(norm_type=norm_type)
 
-            …
+    def _init_continuous_input(self, norm_type):
+        self.norm = torch.nn.GroupNorm(
+            num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
+        )
+        if self.use_linear_projection:
+            self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim)
+        else:
+            self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0)
 
-            …
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
 
-            …
+        if self.use_linear_projection:
+            self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels)
+        else:
+            self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0)
 
-            …
+    def _init_vectorized_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+        assert (
+            self.config.num_vector_embeds is not None
+        ), "Transformer2DModel over discrete input must provide num_embed"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+        self.num_latent_pixels = self.height * self.width
+
+        self.latent_image_embedding = ImagePositionalEmbeddings(
+            num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width
+        )
 
-        # 3. Define transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
                 BasicTransformerBlock(
-                    inner_dim,
-                    num_attention_heads,
-                    attention_head_dim,
-                    dropout=dropout,
-                    cross_attention_dim=cross_attention_dim,
-                    activation_fn=activation_fn,
-                    num_embeds_ada_norm=num_embeds_ada_norm,
-                    attention_bias=attention_bias,
-                    only_cross_attention=only_cross_attention,
-                    double_self_attention=double_self_attention,
-                    upcast_attention=upcast_attention,
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
                     norm_type=norm_type,
-                    norm_elementwise_affine=norm_elementwise_affine,
-                    norm_eps=norm_eps,
-                    attention_type=attention_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
                 )
-                for d in range(num_layers)
+                for _ in range(self.config.num_layers)
             ]
         )
 
-        …
+        self.norm_out = nn.LayerNorm(self.inner_dim)
+        self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1)
+
+    def _init_patched_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+
+        self.patch_size = self.config.patch_size
+        interpolation_scale = (
+            self.config.interpolation_scale
+            if self.config.interpolation_scale is not None
+            else max(self.config.sample_size // 64, 1)
+        )
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.in_channels,
+            embed_dim=self.inner_dim,
+            interpolation_scale=interpolation_scale,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
+
+        if self.config.norm_type != "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
+            self.proj_out_2 = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+        elif self.config.norm_type == "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+            self.proj_out = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+
+        # PixArt-Alpha blocks.
         self.adaln_single = None
-        self.use_additional_conditions = False
-        if norm_type == "ada_norm_single":
-            self.use_additional_conditions = self.config.sample_size == 128
+        if self.config.norm_type == "ada_norm_single":
             # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
             # additional conditions until we find better name
-            self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
+            self.adaln_single = AdaLayerNormSingle(
+                self.inner_dim, use_additional_conditions=self.use_additional_conditions
+            )
 
         self.caption_projection = None
-        if caption_channels is not None:
-            self.caption_projection = PixArtAlphaTextProjection(
-                in_features=caption_channels, hidden_size=inner_dim
-            )
+        if self.caption_channels is not None:
+            self.caption_projection = PixArtAlphaTextProjection(
+                in_features=self.caption_channels, hidden_size=self.inner_dim
+            )
 
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
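With this refactor, `__init__` stores a few derived attributes (`self.inner_dim`, `self.out_channels`, `self.use_additional_conditions`) and then dispatches to one of three `_init_*` helpers, each of which reads its hyperparameters back through `self.config.*` rather than from local variables. A minimal sketch showing that the registered config stays the single source of truth for a tiny continuous-input instance (argument values are illustrative only):

```python
import torch
from diffusers.models.transformers.transformer_2d import Transformer2DModel

tiny = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    num_layers=1,
    norm_num_groups=2,  # must divide in_channels for the GroupNorm in _init_continuous_input
)

print(tiny.config.num_attention_heads)  # 2, read back from the registered config
print(tiny.inner_dim)                   # 16 == num_attention_heads * attention_head_dim
print(tiny.is_input_continuous)         # True: in_channels set, patch_size unset

out = tiny(torch.randn(1, 4, 8, 8)).sample
print(out.shape)                        # torch.Size([1, 4, 8, 8])
```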
@@ -272,9 +338,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         The [`Transformer2DModel`] forward method.
 
         Args:
-            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
                 Input `hidden_states`.
-            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+            encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
             timestep ( `torch.LongTensor`, *optional*):
@@ -308,7 +374,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         """
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is …
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
         # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -334,41 +400,18 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
         # 1. Input
         if self.is_input_continuous:
-            batch, _, height, width = hidden_states.shape
+            batch_size, _, height, width = hidden_states.shape
             residual = hidden_states
-
-            hidden_states = self.norm(hidden_states)
-            if not self.use_linear_projection:
-                hidden_states = self.proj_in(hidden_states)
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-            else:
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-                hidden_states = self.proj_in(hidden_states)
-
+            hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
         elif self.is_input_patches:
             height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
-            hidden_states = self.pos_embed(hidden_states)
-
-            if self.adaln_single is not None:
-                if self.use_additional_conditions and added_cond_kwargs is None:
-                    raise ValueError(
-                        "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
-                    )
-                batch_size = hidden_states.shape[0]
-                timestep, embedded_timestep = self.adaln_single(
-                    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
-                )
+            hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
+                hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
+            )
 
         # 2. Blocks
-        if self.caption_projection is not None:
-            batch_size = hidden_states.shape[0]
-            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-
         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:
 
@@ -406,51 +449,116 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
         # 3. Output
        if self.is_input_continuous:
-            if not self.use_linear_projection:
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-                hidden_states = self.proj_out(hidden_states)
-            else:
-                hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-
-            output = hidden_states + residual
+            output = self._get_output_for_continuous_inputs(
+                hidden_states=hidden_states,
+                residual=residual,
+                batch_size=batch_size,
+                height=height,
+                width=width,
+                inner_dim=inner_dim,
+            )
         elif self.is_input_vectorized:
-            hidden_states = self.norm_out(hidden_states)
-            logits = self.out(hidden_states)
-            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
-            logits = logits.permute(0, 2, 1)
+            output = self._get_output_for_vectorized_inputs(hidden_states)
+        elif self.is_input_patches:
+            output = self._get_output_for_patched_inputs(
+                hidden_states=hidden_states,
+                timestep=timestep,
+                class_labels=class_labels,
+                embedded_timestep=embedded_timestep,
+                height=height,
+                width=width,
+            )
 
-            # log(p(x_0))
-            output = F.log_softmax(logits.double(), dim=1).float()
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
 
-        elif self.is_input_patches:
-            if self.config.norm_type != "ada_norm_single":
-                conditioning = self.transformer_blocks[0].norm1.emb(
-                    timestep, class_labels, hidden_dtype=hidden_states.dtype
+    def _operate_on_continuous_inputs(self, hidden_states):
+        batch, _, height, width = hidden_states.shape
+        hidden_states = self.norm(hidden_states)
+
+        if not self.use_linear_projection:
+            hidden_states = self.proj_in(hidden_states)
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+        else:
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+            hidden_states = self.proj_in(hidden_states)
+
+        return hidden_states, inner_dim
+
+    def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs):
+        batch_size = hidden_states.shape[0]
+        hidden_states = self.pos_embed(hidden_states)
+        embedded_timestep = None
+
+        if self.adaln_single is not None:
+            if self.use_additional_conditions and added_cond_kwargs is None:
+                raise ValueError(
+                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                 )
-                shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
-                hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
-                hidden_states = self.proj_out_2(hidden_states)
-            elif self.config.norm_type == "ada_norm_single":
-                shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
-                hidden_states = self.norm_out(hidden_states)
-                # Modulation
-                hidden_states = hidden_states * (1 + scale) + shift
-                hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.squeeze(1)
-
-            # unpatchify
-            if self.adaln_single is None:
-                height = width = int(hidden_states.shape[1] ** 0.5)
-            hidden_states = hidden_states.reshape(
-                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+            timestep, embedded_timestep = self.adaln_single(
+                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
             )
-            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
-            output = hidden_states.reshape(
-                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+
+        if self.caption_projection is not None:
+            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        return hidden_states, encoder_hidden_states, timestep, embedded_timestep
+
+    def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
+        if not self.use_linear_projection:
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+            )
+            hidden_states = self.proj_out(hidden_states)
+        else:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
             )
 
-        if not return_dict:
-            return (output,)
+        output = hidden_states + residual
+        return output
 
-        return Transformer2DModelOutput(sample=output)
+    def _get_output_for_vectorized_inputs(self, hidden_states):
+        hidden_states = self.norm_out(hidden_states)
+        logits = self.out(hidden_states)
+        # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+        logits = logits.permute(0, 2, 1)
+        # log(p(x_0))
+        output = F.log_softmax(logits.double(), dim=1).float()
+        return output
+
+    def _get_output_for_patched_inputs(
+        self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
+    ):
+        if self.config.norm_type != "ada_norm_single":
+            conditioning = self.transformer_blocks[0].norm1.emb(
+                timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+            hidden_states = self.proj_out_2(hidden_states)
+        elif self.config.norm_type == "ada_norm_single":
+            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states)
+            # Modulation
+            hidden_states = hidden_states * (1 + scale) + shift
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = hidden_states.squeeze(1)
+
+        # unpatchify
+        if self.adaln_single is None:
+            height = width = int(hidden_states.shape[1] ** 0.5)
+        hidden_states = hidden_states.reshape(
+            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+        )
+        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+        )
+        return output