diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2222 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +1 -12
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +262 -2
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1795 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +319 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +1 -4
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +19 -16
- diffusers/utils/loading_utils.py +76 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,10 @@ import torch.nn.functional as F
|
|
22
22
|
|
23
23
|
from ..utils import is_torch_version
|
24
24
|
from .activations import get_activation
|
25
|
-
from .embeddings import
|
25
|
+
from .embeddings import (
|
26
|
+
CombinedTimestepLabelEmbeddings,
|
27
|
+
PixArtAlphaCombinedTimestepSizeEmbeddings,
|
28
|
+
)
|
26
29
|
|
27
30
|
|
28
31
|
class AdaLayerNorm(nn.Module):
|
@@ -31,23 +34,69 @@ class AdaLayerNorm(nn.Module):
|
|
31
34
|
|
32
35
|
Parameters:
|
33
36
|
embedding_dim (`int`): The size of each embedding vector.
|
34
|
-
num_embeddings (`int
|
37
|
+
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
|
38
|
+
output_dim (`int`, *optional*):
|
39
|
+
norm_elementwise_affine (`bool`, defaults to `False):
|
40
|
+
norm_eps (`bool`, defaults to `False`):
|
41
|
+
chunk_dim (`int`, defaults to `0`):
|
35
42
|
"""
|
36
43
|
|
37
|
-
def __init__(
|
44
|
+
def __init__(
|
45
|
+
self,
|
46
|
+
embedding_dim: int,
|
47
|
+
num_embeddings: Optional[int] = None,
|
48
|
+
output_dim: Optional[int] = None,
|
49
|
+
norm_elementwise_affine: bool = False,
|
50
|
+
norm_eps: float = 1e-5,
|
51
|
+
chunk_dim: int = 0,
|
52
|
+
):
|
38
53
|
super().__init__()
|
39
|
-
|
54
|
+
|
55
|
+
self.chunk_dim = chunk_dim
|
56
|
+
output_dim = output_dim or embedding_dim * 2
|
57
|
+
|
58
|
+
if num_embeddings is not None:
|
59
|
+
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
60
|
+
else:
|
61
|
+
self.emb = None
|
62
|
+
|
40
63
|
self.silu = nn.SiLU()
|
41
|
-
self.linear = nn.Linear(embedding_dim,
|
42
|
-
self.norm = nn.LayerNorm(
|
64
|
+
self.linear = nn.Linear(embedding_dim, output_dim)
|
65
|
+
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
66
|
+
|
67
|
+
def forward(
|
68
|
+
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
|
69
|
+
) -> torch.Tensor:
|
70
|
+
if self.emb is not None:
|
71
|
+
temb = self.emb(timestep)
|
72
|
+
|
73
|
+
temb = self.linear(self.silu(temb))
|
74
|
+
|
75
|
+
if self.chunk_dim == 1:
|
76
|
+
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
|
77
|
+
# other if-branch. This branch is specific to CogVideoX for now.
|
78
|
+
shift, scale = temb.chunk(2, dim=1)
|
79
|
+
shift = shift[:, None, :]
|
80
|
+
scale = scale[:, None, :]
|
81
|
+
else:
|
82
|
+
scale, shift = temb.chunk(2, dim=0)
|
43
83
|
|
44
|
-
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
|
45
|
-
emb = self.linear(self.silu(self.emb(timestep)))
|
46
|
-
scale, shift = torch.chunk(emb, 2)
|
47
84
|
x = self.norm(x) * (1 + scale) + shift
|
48
85
|
return x
|
49
86
|
|
50
87
|
|
88
|
+
class FP32LayerNorm(nn.LayerNorm):
|
89
|
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
90
|
+
origin_dtype = inputs.dtype
|
91
|
+
return F.layer_norm(
|
92
|
+
inputs.float(),
|
93
|
+
self.normalized_shape,
|
94
|
+
self.weight.float() if self.weight is not None else None,
|
95
|
+
self.bias.float() if self.bias is not None else None,
|
96
|
+
self.eps,
|
97
|
+
).to(origin_dtype)
|
98
|
+
|
99
|
+
|
51
100
|
class AdaLayerNormZero(nn.Module):
|
52
101
|
r"""
|
53
102
|
Norm layer adaptive layer norm zero (adaLN-Zero).
|
@@ -57,7 +106,7 @@ class AdaLayerNormZero(nn.Module):
|
|
57
106
|
num_embeddings (`int`): The size of the embeddings dictionary.
|
58
107
|
"""
|
59
108
|
|
60
|
-
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
|
109
|
+
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
|
61
110
|
super().__init__()
|
62
111
|
if num_embeddings is not None:
|
63
112
|
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
@@ -65,8 +114,15 @@ class AdaLayerNormZero(nn.Module):
|
|
65
114
|
self.emb = None
|
66
115
|
|
67
116
|
self.silu = nn.SiLU()
|
68
|
-
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=
|
69
|
-
|
117
|
+
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
|
118
|
+
if norm_type == "layer_norm":
|
119
|
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
120
|
+
elif norm_type == "fp32_layer_norm":
|
121
|
+
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
|
122
|
+
else:
|
123
|
+
raise ValueError(
|
124
|
+
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
125
|
+
)
|
70
126
|
|
71
127
|
def forward(
|
72
128
|
self,
|
@@ -84,6 +140,69 @@ class AdaLayerNormZero(nn.Module):
|
|
84
140
|
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
85
141
|
|
86
142
|
|
143
|
+
class AdaLayerNormZeroSingle(nn.Module):
|
144
|
+
r"""
|
145
|
+
Norm layer adaptive layer norm zero (adaLN-Zero).
|
146
|
+
|
147
|
+
Parameters:
|
148
|
+
embedding_dim (`int`): The size of each embedding vector.
|
149
|
+
num_embeddings (`int`): The size of the embeddings dictionary.
|
150
|
+
"""
|
151
|
+
|
152
|
+
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
|
153
|
+
super().__init__()
|
154
|
+
|
155
|
+
self.silu = nn.SiLU()
|
156
|
+
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
|
157
|
+
if norm_type == "layer_norm":
|
158
|
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
159
|
+
else:
|
160
|
+
raise ValueError(
|
161
|
+
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
162
|
+
)
|
163
|
+
|
164
|
+
def forward(
|
165
|
+
self,
|
166
|
+
x: torch.Tensor,
|
167
|
+
emb: Optional[torch.Tensor] = None,
|
168
|
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
169
|
+
emb = self.linear(self.silu(emb))
|
170
|
+
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
|
171
|
+
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
172
|
+
return x, gate_msa
|
173
|
+
|
174
|
+
|
175
|
+
class LuminaRMSNormZero(nn.Module):
|
176
|
+
"""
|
177
|
+
Norm layer adaptive RMS normalization zero.
|
178
|
+
|
179
|
+
Parameters:
|
180
|
+
embedding_dim (`int`): The size of each embedding vector.
|
181
|
+
"""
|
182
|
+
|
183
|
+
def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool):
|
184
|
+
super().__init__()
|
185
|
+
self.silu = nn.SiLU()
|
186
|
+
self.linear = nn.Linear(
|
187
|
+
min(embedding_dim, 1024),
|
188
|
+
4 * embedding_dim,
|
189
|
+
bias=True,
|
190
|
+
)
|
191
|
+
self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
192
|
+
|
193
|
+
def forward(
|
194
|
+
self,
|
195
|
+
x: torch.Tensor,
|
196
|
+
emb: Optional[torch.Tensor] = None,
|
197
|
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
198
|
+
# emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
|
199
|
+
emb = self.linear(self.silu(emb))
|
200
|
+
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
201
|
+
x = self.norm(x) * (1 + scale_msa[:, None])
|
202
|
+
|
203
|
+
return x, gate_msa, scale_mlp, gate_mlp
|
204
|
+
|
205
|
+
|
87
206
|
class AdaLayerNormSingle(nn.Module):
|
88
207
|
r"""
|
89
208
|
Norm layer adaptive layer norm single (adaLN-single).
|
@@ -188,6 +307,78 @@ class AdaLayerNormContinuous(nn.Module):
|
|
188
307
|
return x
|
189
308
|
|
190
309
|
|
310
|
+
class LuminaLayerNormContinuous(nn.Module):
|
311
|
+
def __init__(
|
312
|
+
self,
|
313
|
+
embedding_dim: int,
|
314
|
+
conditioning_embedding_dim: int,
|
315
|
+
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
|
316
|
+
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
|
317
|
+
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
|
318
|
+
# However, this is how it was implemented in the original code, and it's rather likely you should
|
319
|
+
# set `elementwise_affine` to False.
|
320
|
+
elementwise_affine=True,
|
321
|
+
eps=1e-5,
|
322
|
+
bias=True,
|
323
|
+
norm_type="layer_norm",
|
324
|
+
out_dim: Optional[int] = None,
|
325
|
+
):
|
326
|
+
super().__init__()
|
327
|
+
# AdaLN
|
328
|
+
self.silu = nn.SiLU()
|
329
|
+
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
|
330
|
+
if norm_type == "layer_norm":
|
331
|
+
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
332
|
+
else:
|
333
|
+
raise ValueError(f"unknown norm_type {norm_type}")
|
334
|
+
# linear_2
|
335
|
+
if out_dim is not None:
|
336
|
+
self.linear_2 = nn.Linear(
|
337
|
+
embedding_dim,
|
338
|
+
out_dim,
|
339
|
+
bias=bias,
|
340
|
+
)
|
341
|
+
|
342
|
+
def forward(
|
343
|
+
self,
|
344
|
+
x: torch.Tensor,
|
345
|
+
conditioning_embedding: torch.Tensor,
|
346
|
+
) -> torch.Tensor:
|
347
|
+
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
|
348
|
+
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
|
349
|
+
scale = emb
|
350
|
+
x = self.norm(x) * (1 + scale)[:, None, :]
|
351
|
+
|
352
|
+
if self.linear_2 is not None:
|
353
|
+
x = self.linear_2(x)
|
354
|
+
|
355
|
+
return x
|
356
|
+
|
357
|
+
|
358
|
+
class CogVideoXLayerNormZero(nn.Module):
|
359
|
+
def __init__(
|
360
|
+
self,
|
361
|
+
conditioning_dim: int,
|
362
|
+
embedding_dim: int,
|
363
|
+
elementwise_affine: bool = True,
|
364
|
+
eps: float = 1e-5,
|
365
|
+
bias: bool = True,
|
366
|
+
) -> None:
|
367
|
+
super().__init__()
|
368
|
+
|
369
|
+
self.silu = nn.SiLU()
|
370
|
+
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
|
371
|
+
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
372
|
+
|
373
|
+
def forward(
|
374
|
+
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
|
375
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
376
|
+
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
|
377
|
+
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
|
378
|
+
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
|
379
|
+
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
|
380
|
+
|
381
|
+
|
191
382
|
if is_torch_version(">=", "2.1.0"):
|
192
383
|
LayerNorm = nn.LayerNorm
|
193
384
|
else:
|
@@ -2,12 +2,18 @@ from ...utils import is_torch_available
|
|
2
2
|
|
3
3
|
|
4
4
|
if is_torch_available():
|
5
|
+
from .auraflow_transformer_2d import AuraFlowTransformer2DModel
|
6
|
+
from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
|
5
7
|
from .dit_transformer_2d import DiTTransformer2DModel
|
6
8
|
from .dual_transformer_2d import DualTransformer2DModel
|
7
9
|
from .hunyuan_transformer_2d import HunyuanDiT2DModel
|
10
|
+
from .latte_transformer_3d import LatteTransformer3DModel
|
11
|
+
from .lumina_nextdit2d import LuminaNextDiT2DModel
|
8
12
|
from .pixart_transformer_2d import PixArtTransformer2DModel
|
9
13
|
from .prior_transformer import PriorTransformer
|
14
|
+
from .stable_audio_transformer import StableAudioDiTModel
|
10
15
|
from .t5_film_transformer import T5FilmDecoder
|
11
16
|
from .transformer_2d import Transformer2DModel
|
17
|
+
from .transformer_flux import FluxTransformer2DModel
|
12
18
|
from .transformer_sd3 import SD3Transformer2DModel
|
13
19
|
from .transformer_temporal import TransformerTemporalModel
|