diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2222 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +1 -12
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +262 -2
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1795 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +319 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +1 -4
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +19 -16
- diffusers/utils/loading_utils.py +76 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,327 @@
|
|
1
|
+
# Copyright 2024 the Latte Team and The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
from typing import Optional
|
15
|
+
|
16
|
+
import torch
|
17
|
+
from torch import nn
|
18
|
+
|
19
|
+
from ...configuration_utils import ConfigMixin, register_to_config
|
20
|
+
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
|
21
|
+
from ..attention import BasicTransformerBlock
|
22
|
+
from ..embeddings import PatchEmbed
|
23
|
+
from ..modeling_outputs import Transformer2DModelOutput
|
24
|
+
from ..modeling_utils import ModelMixin
|
25
|
+
from ..normalization import AdaLayerNormSingle
|
26
|
+
|
27
|
+
|
28
|
+
class LatteTransformer3DModel(ModelMixin, ConfigMixin):
    """
    A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, official code:
    https://github.com/Vchitect/Latte

    Each layer runs a spatial `BasicTransformerBlock` over the patches of every frame, followed (optionally) by a
    temporal `BasicTransformerBlock` over the frames of every patch position.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input.
        out_channels (`int`, *optional*):
            The number of channels in the output. Defaults to `in_channels` when not given.
        num_layers (`int`, *optional*, defaults to 1):
            The number of (spatial, temporal) Transformer block pairs to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*):
            The `encoder_hidden_states` dimension used by the spatial blocks' cross-attention. The temporal blocks
            have no cross-attention.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
        sample_size (`int`, *optional*, defaults to 64):
            The height and width of the latent frames. Used to configure the patch embedding and its interpolation
            scale (`max(sample_size // 64, 1)`).
        patch_size (`int`, *optional*):
            The size of the patches to use in the patch embedding layer.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states. During inference, you can denoise for up to but not more steps than
            `num_embeds_ada_norm`.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization type passed to every `BasicTransformerBlock` (see that class for supported values).
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether or not to use elementwise affine in normalization layers.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
        caption_channels (`int`, *optional*):
            The number of channels in the caption embeddings (input width of the caption projection).
        video_length (`int`, *optional*, defaults to 16):
            The number of frames the temporal positional embedding is built for. Inputs that use temporal attention
            may have at most this many frames.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: int = 64,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        caption_channels: Optional[int] = None,
        video_length: int = 16,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Define input layers
        self.height = sample_size
        self.width = sample_size

        # Positional-embedding interpolation is relative to a 64-wide latent and never drops below 1.
        interpolation_scale = max(sample_size // 64, 1)
        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            interpolation_scale=interpolation_scale,
        )

        # 2. Define spatial transformers blocks (self-attention over patches + cross-attention to captions)
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # 3. Define temporal transformers blocks (self-attention over frames only, no cross-attention)
        self.temporal_transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=None,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

        # 5. Latte other blocks.
        self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

        # Fixed sinusoidal temporal positional embedding, one row per frame position 0..video_length-1,
        # stored with a leading broadcast dim (presumably (1, video_length, inner_dim) -- confirm against helper).
        # Non-persistent: it is recomputed here rather than stored in checkpoints.
        temp_pos_embed = get_1d_sincos_pos_embed_from_grid(inner_dim, torch.arange(0, video_length).unsqueeze(1))
        self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    def _run_block(self, block, hidden_states, encoder_hidden_states, encoder_attention_mask, timestep):
        """Call one `BasicTransformerBlock`, through activation checkpointing when training with it enabled."""
        block_args = (
            hidden_states,
            None,  # attention_mask
            encoder_hidden_states,
            encoder_attention_mask,
            timestep,
            None,  # cross_attention_kwargs
            None,  # class_labels
        )
        if self.training and self.gradient_checkpointing:
            return torch.utils.checkpoint.checkpoint(block, *block_args, use_reentrant=False)
        return block(*block_args)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        enable_temporal_attentions: bool = True,
        return_dict: bool = True,
    ):
        """
        The [`LatteTransformer3DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch size, channel, num_frame, height, width)`):
                Input `hidden_states`.
            timestep (`torch.LongTensor`, *optional*):
                Used to indicate denoising step. Embedded by `AdaLayerNormSingle`; the result modulates every block
                and the output projection.
            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence len, caption_channels)`, *optional*):
                Conditional (caption) embeddings. They are projected to the model width and repeated once per frame
                for the spatial blocks' cross-attention. If `None`, the projection is skipped and `None` is forwarded
                to the spatial blocks.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. In both cases it is repeated once per frame to match the spatial blocks' batch layout.
            enable_temporal_attentions (`bool`, *optional*, defaults to `True`):
                Whether to run the temporal transformer blocks (and add the temporal positional embedding). When
                enabled with `num_frame > 1`, `num_frame` must not exceed the configured `video_length`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.modeling_outputs.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, a [`~models.modeling_outputs.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` whose first element is the sample tensor of shape
            `(batch size, out_channels, num_frame, height, width)` (for height/width divisible by `patch_size`).
        """
        batch_size, channels, num_frame, height, width = hidden_states.shape
        # (B, C, F, H, W) -> (B * F, C, H, W): every frame goes through the patch embedding as its own image.
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width)

        # From here on, height/width count patches, not pixels.
        patch_size = self.config.patch_size
        height, width = height // patch_size, width // patch_size
        num_patches = height * width

        hidden_states = self.pos_embed(hidden_states)  # already adds positional embeddings

        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
        timestep, embedded_timestep = self.adaln_single(
            timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        # Prepare text embeddings for the spatial blocks:
        # (B, num_tokens, hidden) -> (B * F, num_tokens, hidden)
        if encoder_hidden_states is not None:
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
                -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
            )
        else:
            encoder_hidden_states_spatial = None

        # Prepare the cross-attention mask for the spatial blocks.
        if encoder_attention_mask is not None:
            if encoder_attention_mask.ndim == 2:
                # keep/discard mask -> additive bias (0 keep, -10000 discard) with a broadcast query axis.
                encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
                encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
            # Spatial blocks see a batch of B * F sequences, so repeat the per-sample mask once per frame.
            encoder_attention_mask = encoder_attention_mask.repeat_interleave(num_frame, dim=0)

        # Prepare timesteps: one row per frame for the spatial blocks, one row per patch for the temporal blocks.
        timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
        timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])

        # Spatial and temporal transformer blocks
        for i, (spatial_block, temp_block) in enumerate(
            zip(self.transformer_blocks, self.temporal_transformer_blocks)
        ):
            hidden_states = self._run_block(
                spatial_block,
                hidden_states,
                encoder_hidden_states_spatial,
                encoder_attention_mask,
                timestep_spatial,
            )

            if enable_temporal_attentions:
                # (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size
                hidden_states = hidden_states.reshape(
                    batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
                ).permute(0, 2, 1, 3)
                hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])

                if i == 0 and num_frame > 1:
                    # Slice so clips shorter than `video_length` still broadcast.
                    hidden_states = hidden_states + self.temp_pos_embed[:, :num_frame]

                # Temporal blocks: self-attention only (no encoder states / mask).
                hidden_states = self._run_block(temp_block, hidden_states, None, None, timestep_temp)

                # (batch_size * num_tokens) num_frame hidden_size -> (batch_size * num_frame) num_tokens hidden_size
                hidden_states = hidden_states.reshape(
                    batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
                ).permute(0, 2, 1, 3)
                hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])

        # Final norm + shift/scale modulation from the embedded timestep (one row per frame).
        embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states)
        hidden_states = hidden_states * (1 + scale) + shift
        hidden_states = self.proj_out(hidden_states)

        # Unpatchify: (B * F, P, p * p * C_out) -> (B * F, C_out, height * p, width * p)
        hidden_states = hidden_states.reshape(
            shape=(-1, height, width, patch_size, patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(shape=(-1, self.out_channels, height * patch_size, width * patch_size))
        # (B * F, C_out, H, W) -> (B, C_out, F, H, W)
        output = output.reshape(batch_size, -1, output.shape[-3], output.shape[-2], output.shape[-1]).permute(
            0, 2, 1, 3, 4
        )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|
@@ -0,0 +1,340 @@
|
|
1
|
+
# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Any, Dict, Optional
|
16
|
+
|
17
|
+
import torch
|
18
|
+
import torch.nn as nn
|
19
|
+
|
20
|
+
from ...configuration_utils import ConfigMixin, register_to_config
|
21
|
+
from ...utils import logging
|
22
|
+
from ..attention import LuminaFeedForward
|
23
|
+
from ..attention_processor import Attention, LuminaAttnProcessor2_0
|
24
|
+
from ..embeddings import (
|
25
|
+
LuminaCombinedTimestepCaptionEmbedding,
|
26
|
+
LuminaPatchEmbed,
|
27
|
+
)
|
28
|
+
from ..modeling_outputs import Transformer2DModelOutput
|
29
|
+
from ..modeling_utils import ModelMixin
|
30
|
+
from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
|
31
|
+
|
32
|
+
|
33
|
+
# Module-level logger, named after this module via the package's `logging` utility.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
34
|
+
|
35
|
+
|
36
|
+
class LuminaNextDiTBlock(nn.Module):
    """
    A single transformer block of `LuminaNextDiT2DModel`.

    The block feeds the same normalized input to a self-attention layer and a cross-attention layer. It adds their
    per-head outputs (the cross-attention output is first scaled by a learned per-head `tanh` gate) and projects the
    sum with one shared output projection. A modulated feed-forward network follows. Both residual branches are
    scaled by `tanh` gates derived from `temb`.

    Parameters:
        dim (`int`): Embedding dimension of the input features.
        num_attention_heads (`int`): Number of attention heads.
        num_kv_heads (`int`):
            Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
        multiple_of (`int`): Forwarded to `LuminaFeedForward` as `multiple_of` when sizing the feed-forward layer.
        ffn_dim_multiplier (`float`):
            Forwarded to `LuminaFeedForward` as the multiplier applied to the feed-forward hidden dimension.
        norm_eps (`float`): The epsilon used by the norm layers of this block.
        qk_norm (`bool`): If True, both attention layers apply `"layer_norm_across_heads"` to queries and keys.
        cross_attention_dim (`int`): Embedding dimension of the text-prompt hidden states used for cross-attention.
        norm_elementwise_affine (`bool`, *optional*, defaults to True):
            Whether the norm layers of this block have learnable elementwise affine parameters.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        qk_norm: bool,
        cross_attention_dim: int,
        norm_elementwise_affine: bool = True,
    ) -> None:
        super().__init__()
        self.head_dim = dim // num_attention_heads

        # Per-head gate for the cross-attention branch. It is initialized to zero and used via tanh() in forward, so
        # the cross-attention contribution starts at exactly zero.
        self.gate = nn.Parameter(torch.zeros([num_attention_heads]))

        # Self-attention
        self.attn1 = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            qk_norm="layer_norm_across_heads" if qk_norm else None,
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=LuminaAttnProcessor2_0(),
        )
        # The self-attention output is not projected on its own: forward() adds it to the cross-attention output
        # and projects the sum once with `self.attn2.to_out[0]`.
        self.attn1.to_out = nn.Identity()

        # Cross-attention
        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            dim_head=dim // num_attention_heads,
            qk_norm="layer_norm_across_heads" if qk_norm else None,
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=LuminaAttnProcessor2_0(),
        )

        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )

        # Produces the normalized input plus the attention gate and the MLP scale/gate from `temb`.
        self.norm1 = LuminaRMSNormZero(
            embedding_dim=dim,
            norm_eps=norm_eps,
            norm_elementwise_affine=norm_elementwise_affine,
        )
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)

        self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)

        self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_mask: torch.Tensor,
        temb: torch.Tensor,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """
        Perform a forward pass through the LuminaNextDiTBlock.

        Parameters:
            hidden_states (`torch.Tensor`): The image-token hidden states fed into the block.
            attention_mask (`torch.Tensor`): Attention mask for `hidden_states`, used by the self-attention layer.
            image_rotary_emb (`torch.Tensor`):
                Precomputed rotary frequencies. They are applied to the self-attention queries and keys, and to the
                cross-attention queries only.
            encoder_hidden_states (`torch.Tensor`): Text-prompt hidden states produced by the text encoder.
            encoder_mask (`torch.Tensor`): Attention mask for `encoder_hidden_states`, used by cross-attention.
            temb (`torch.Tensor`): Combined timestep and caption embedding that conditions the norm and gates.
            cross_attention_kwargs (`Dict[str, Any]`, *optional*):
                Extra keyword arguments forwarded to both attention layers. `None` is treated as no extra arguments.

        Returns:
            `torch.Tensor`: The updated hidden states.
        """
        # `None` is the documented default; unpacking it with `**` below would raise a TypeError.
        if cross_attention_kwargs is None:
            cross_attention_kwargs = {}

        residual = hidden_states

        # Self-attention
        norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
        self_attn_output = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_hidden_states,
            attention_mask=attention_mask,
            query_rotary_emb=image_rotary_emb,
            key_rotary_emb=image_rotary_emb,
            **cross_attention_kwargs,
        )

        # Cross-attention: only the queries (image tokens) receive rotary embeddings.
        norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
        cross_attn_output = self.attn2(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=encoder_mask,
            query_rotary_emb=image_rotary_emb,
            key_rotary_emb=None,
            **cross_attention_kwargs,
        )
        # Scale each head of the cross-attention output by its learned tanh gate before merging the branches.
        cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1)
        mixed_attn_output = self_attn_output + cross_attn_output
        # Merge the trailing (heads, head_dim) axes before the shared output projection.
        mixed_attn_output = mixed_attn_output.flatten(-2)
        # linear proj (shared by both attention branches; attn1.to_out is an Identity)
        hidden_states = self.attn2.to_out[0](mixed_attn_output)

        hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states)

        mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))

        hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)

        return hidden_states
|
176
|
+
|
177
|
+
|
178
|
+
class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
    """
    LuminaNextDiT: Diffusion model with a Transformer backbone.

    Inherits `ModelMixin` and `ConfigMixin` so that it can be saved, loaded and configured like other diffusers
    models.

    Parameters:
        sample_size (`int`, *optional*, defaults to 128):
            The width of the latent images. This is fixed during training since it is used to learn a number of
            position embeddings.
        patch_size (`int`, *optional*, defaults to 2):
            The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
        in_channels (`int`, *optional*, defaults to 4):
            The number of input channels for the model. Typically, this matches the number of channels in the input
            images.
        hidden_size (`int`, *optional*, defaults to 2304):
            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
            hidden representations. `hidden_size // num_attention_heads` must be divisible by 4 (2D RoPE).
        num_layers (`int`, *optional*, defaults to 32):
            The number of `LuminaNextDiTBlock` layers in the model.
        num_attention_heads (`int`, *optional*, defaults to 32):
            The number of attention heads in each attention layer.
        num_kv_heads (`int`, *optional*, defaults to None):
            The number of key-value heads in the attention mechanism, if different from the number of attention heads.
            If None, it defaults to num_attention_heads.
        multiple_of (`int`, *optional*, defaults to 256):
            Forwarded to each block's feed-forward layer as `multiple_of` when sizing its hidden dimension.
        ffn_dim_multiplier (`float`, *optional*):
            Forwarded to each block's feed-forward layer as the multiplier of its hidden dimension.
        norm_eps (`float`, *optional*, defaults to 1e-5):
            A small value added to the denominator for numerical stability in the blocks' normalization layers.
        learn_sigma (`bool`, *optional*, defaults to True):
            If True, the model outputs `2 * in_channels` channels instead of `in_channels`.
        qk_norm (`bool`, *optional*, defaults to True):
            Indicates if the queries and keys in the attention mechanism should be normalized.
        cross_attention_dim (`int`, *optional*, defaults to 2048):
            The dimensionality of the text embeddings used for cross-attention.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            Stored on the model as `self.scaling_factor`; it is not used by this class itself.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: Optional[int] = 2,
        in_channels: Optional[int] = 4,
        hidden_size: Optional[int] = 2304,
        num_layers: Optional[int] = 32,
        num_attention_heads: Optional[int] = 32,
        num_kv_heads: Optional[int] = None,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: Optional[float] = 1e-5,
        learn_sigma: Optional[bool] = True,
        qk_norm: Optional[bool] = True,
        cross_attention_dim: Optional[int] = 2048,
        scaling_factor: Optional[float] = 1.0,
    ) -> None:
        super().__init__()
        self.sample_size = sample_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.scaling_factor = scaling_factor

        # Validate the configuration before allocating any (potentially large) submodules.
        assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"

        self.patch_embedder = LuminaPatchEmbed(
            patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True
        )

        # Allocated uninitialized; its values are expected to come from loaded weights.
        self.pad_token = nn.Parameter(torch.empty(hidden_size))

        # The conditioning embedding width is capped at 1024; `norm_out` below uses the same cap.
        self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding(
            hidden_size=min(hidden_size, 1024), cross_attention_dim=cross_attention_dim
        )

        self.layers = nn.ModuleList(
            [
                LuminaNextDiTBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    cross_attention_dim,
                )
                for _ in range(num_layers)
            ]
        )
        # Final conditioned norm + projection to patch_size * patch_size * out_channels values per token.
        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            bias=True,
            out_dim=patch_size * patch_size * self.out_channels,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict=True,
    ) -> torch.Tensor:
        """
        Forward pass of LuminaNextDiT.

        Parameters:
            hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
            timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
            encoder_hidden_states (torch.Tensor):
                Caption features. They must align with `encoder_mask`, so presumably of shape (N, L, D).
            encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L). It is cast to bool for the layers.
            image_rotary_emb (torch.Tensor):
                Precomputed rotary frequencies. They are passed through `self.patch_embedder`, which returns the
                version actually used by the layers.
            cross_attention_kwargs (`Dict[str, Any]`, *optional*):
                Extra keyword arguments forwarded to every attention layer. `None` means no extra arguments.
            return_dict (`bool`, *optional*, defaults to True):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.

        Returns:
            `Transformer2DModelOutput` with `sample` of shape (N, out_channels, H, W) if `return_dict` is True,
            otherwise the one-element tuple `(sample,)`.
        """
        # Normalize once so every layer can safely `**`-unpack the kwargs.
        if cross_attention_kwargs is None:
            cross_attention_kwargs = {}

        hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
        image_rotary_emb = image_rotary_emb.to(hidden_states.device)

        temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)

        encoder_mask = encoder_mask.bool()
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                mask,
                image_rotary_emb,
                encoder_hidden_states,
                encoder_mask,
                temb=temb,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        hidden_states = self.norm_out(hidden_states, temb)

        # Unpatchify: tokens of patch_size * patch_size * out_channels values -> (N, out_channels, H, W).
        patch_height = patch_width = self.patch_size
        # NOTE(review): only the first entry of `img_size` is used, so all samples are assumed to share one size.
        height, width = img_size[0]
        batch_size = hidden_states.size(0)
        sequence_length = (height // patch_height) * (width // patch_width)
        # Keep only the tokens that belong to the image grid (any trailing tokens are presumably padding).
        hidden_states = hidden_states[:, :sequence_length].view(
            batch_size, height // patch_height, width // patch_width, patch_height, patch_width, self.out_channels
        )
        # (N, H/p, W/p, p, p, C) -> (N, C, H/p, p, W/p, p) -> (N, C, H, W)
        output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|