diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2222 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +1 -12
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +262 -2
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1795 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +319 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +1 -4
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +19 -16
- diffusers/utils/loading_utils.py +76 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
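Among the additions in this release is the AuraFlow stack: the `AuraFlowTransformer2DModel` whose source appears in the hunk below, plus `diffusers/pipelines/aura_flow/pipeline_aura_flow.py`. As a point of reference, here is a minimal usage sketch of the new pipeline; the class name `AuraFlowPipeline`, the `fal/AuraFlow` checkpoint id, and the call arguments are assumptions based on the standard diffusers text-to-image interface, not taken from this diff.

# Hypothetical quick check of the AuraFlow text-to-image pipeline added in 0.30.0.
# `AuraFlowPipeline` and the "fal/AuraFlow" checkpoint id are assumptions, not part of this diff.
import torch

from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe(
    prompt="a photo of an astronaut riding a horse on the moon",
    num_inference_steps=50,
    guidance_scale=3.5,
).images[0]
image.save("auraflow_sample.png")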
diffusers/models/transformers/auraflow_transformer_2d.py (new file)
@@ -0,0 +1,527 @@
# Copyright 2024 AuraFlow Authors, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Any, Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
    Attention,
    AttentionProcessor,
    AuraFlowAttnProcessor2_0,
    FusedAuraFlowAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormZero, FP32LayerNorm


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Taken from the original aura flow inference code.
def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)


# Aura Flow patch embed doesn't use convs for projections.
# Additionally, it uses learned positional embeddings.
class AuraFlowPatchEmbed(nn.Module):
    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        pos_embed_max_size=None,
    ):
        super().__init__()

        self.num_patches = (height // patch_size) * (width // patch_size)
        self.pos_embed_max_size = pos_embed_max_size

        self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1)

        self.patch_size = patch_size
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size

    def forward(self, latent):
        batch_size, num_channels, height, width = latent.size()
        latent = latent.view(
            batch_size,
            num_channels,
            height // self.patch_size,
            self.patch_size,
            width // self.patch_size,
            self.patch_size,
        )
        latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        latent = self.proj(latent)
        return latent + self.pos_embed


# Taken from the original Aura flow inference code.
# Our feedforward only has GELU but Aura uses SiLU.
class AuraFlowFeedForward(nn.Module):
    def __init__(self, dim, hidden_dim=None) -> None:
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * dim

        final_hidden_dim = int(2 * hidden_dim / 3)
        final_hidden_dim = find_multiple(final_hidden_dim, 256)

        self.linear_1 = nn.Linear(dim, final_hidden_dim, bias=False)
        self.linear_2 = nn.Linear(dim, final_hidden_dim, bias=False)
        self.out_projection = nn.Linear(final_hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.linear_1(x)) * self.linear_2(x)
        x = self.out_projection(x)
        return x


class AuraFlowPreFinalBlock(nn.Module):
    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=False)

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = x * (1 + scale)[:, None, :] + shift[:, None, :]
        return x


@maybe_allow_in_graph
class AuraFlowSingleTransformerBlock(nn.Module):
    """Similar to `AuraFlowJointTransformerBlock` with a single DiT instead of an MMDiT."""

    def __init__(self, dim, num_attention_heads, attention_head_dim):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")

        processor = AuraFlowAttnProcessor2_0()
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="fp32_layer_norm",
            out_dim=dim,
            bias=False,
            out_bias=False,
            processor=processor,
        )

        self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
        self.ff = AuraFlowFeedForward(dim, dim * 4)

    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
        residual = hidden_states

        # Norm + Projection.
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        # Attention.
        attn_output = self.attn(hidden_states=norm_hidden_states)

        # Process attention outputs for the `hidden_states`.
        hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
        hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self.ff(hidden_states)
        hidden_states = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = residual + hidden_states

        return hidden_states


@maybe_allow_in_graph
class AuraFlowJointTransformerBlock(nn.Module):
    r"""
    Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive):

    * QK Norm in the attention blocks
    * No bias in the attention blocks
    * Most LayerNorms are in FP32

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        is_last (`bool`): Boolean to determine if this is the last block in the model.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
        self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")

        processor = AuraFlowAttnProcessor2_0()
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            added_proj_bias=False,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="fp32_layer_norm",
            out_dim=dim,
            bias=False,
            out_bias=False,
            processor=processor,
            context_pre_only=False,
        )

        self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
        self.ff = AuraFlowFeedForward(dim, dim * 4)
        self.norm2_context = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
        self.ff_context = AuraFlowFeedForward(dim, dim * 4)

    def forward(
        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
    ):
        residual = hidden_states
        residual_context = encoder_hidden_states

        # Norm + Projection.
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )

        # Attention.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
        )

        # Process attention outputs for the `hidden_states`.
        hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
        hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        hidden_states = gate_mlp.unsqueeze(1) * self.ff(hidden_states)
        hidden_states = residual + hidden_states

        # Process attention outputs for the `encoder_hidden_states`.
        encoder_hidden_states = self.norm2_context(residual_context + c_gate_msa.unsqueeze(1) * context_attn_output)
        encoder_hidden_states = encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
        encoder_hidden_states = c_gate_mlp.unsqueeze(1) * self.ff_context(encoder_hidden_states)
        encoder_hidden_states = residual_context + encoder_hidden_states

        return encoder_hidden_states, hidden_states


class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).

    Parameters:
        sample_size (`int`): The width of the latent images. This is fixed during training since
            it is used to learn a number of position embeddings.
        patch_size (`int`): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
        num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
        num_single_dit_layers (`int`, *optional*, defaults to 4):
            The number of layers of Transformer blocks to use. These blocks use concatenated image and text
            representations.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
        out_channels (`int`, defaults to 16): Number of output channels.
        pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: int = 64,
        patch_size: int = 2,
        in_channels: int = 4,
        num_mmdit_layers: int = 4,
        num_single_dit_layers: int = 32,
        attention_head_dim: int = 256,
        num_attention_heads: int = 12,
        joint_attention_dim: int = 2048,
        caption_projection_dim: int = 3072,
        out_channels: int = 4,
        pos_embed_max_size: int = 1024,
    ):
        super().__init__()
        default_out_channels = in_channels
        self.out_channels = out_channels if out_channels is not None else default_out_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        self.pos_embed = AuraFlowPatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,
        )

        self.context_embedder = nn.Linear(
            self.config.joint_attention_dim, self.config.caption_projection_dim, bias=False
        )
        self.time_step_embed = Timesteps(num_channels=256, downscale_freq_shift=0, scale=1000, flip_sin_to_cos=True)
        self.time_step_proj = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)

        self.joint_transformer_blocks = nn.ModuleList(
            [
                AuraFlowJointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for i in range(self.config.num_mmdit_layers)
            ]
        )
        self.single_transformer_blocks = nn.ModuleList(
            [
                AuraFlowSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for _ in range(self.config.num_single_dit_layers)
            ]
        )

        self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)

        # https://arxiv.org/abs/2309.16588
        # prevents artifacts in the attention maps
        self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02)

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedAuraFlowAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        timestep: torch.LongTensor = None,
        return_dict: bool = True,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        height, width = hidden_states.shape[-2:]

        # Apply patch embedding, timestep embedding, and project the caption embeddings.
        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
        temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype)
        temb = self.time_step_proj(temb)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
        encoder_hidden_states = torch.cat(
            [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
        )

        # MMDiT blocks.
        for index_block, block in enumerate(self.joint_transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
                )

        # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
        if len(self.single_transformer_blocks) > 0:
            encoder_seq_len = encoder_hidden_states.size(1)
            combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

            for index_block, block in enumerate(self.single_transformer_blocks):
                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    combined_hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        combined_hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )

                else:
                    combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb)

            hidden_states = combined_hidden_states[:, encoder_seq_len:]

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # unpatchify
        patch_size = self.config.patch_size
        out_channels = self.config.out_channels
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], height, width, patch_size, patch_size, out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
        )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
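For reference, a minimal sketch of exercising the module above directly with a deliberately tiny configuration so it runs quickly on CPU. The hyperparameter values and tensor shapes are illustrative assumptions derived from the signatures shown, not the released AuraFlow configuration; note that `caption_projection_dim` must equal `num_attention_heads * attention_head_dim`, and the number of latent patches must equal `pos_embed_max_size` in this version of the patch embed.

# Illustrative sketch only: a tiny AuraFlowTransformer2DModel and a dummy forward pass.
import torch

from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowTransformer2DModel

model = AuraFlowTransformer2DModel(
    sample_size=32,            # latent height/width
    patch_size=2,
    in_channels=4,
    num_mmdit_layers=1,
    num_single_dit_layers=1,
    attention_head_dim=16,
    num_attention_heads=4,     # inner_dim = 4 * 16 = 64
    joint_attention_dim=32,
    caption_projection_dim=64, # must match inner_dim
    out_channels=4,
    pos_embed_max_size=256,    # (32 / 2) ** 2 patches
)

latents = torch.randn(1, 4, 32, 32)      # (batch, in_channels, height, width)
prompt_embeds = torch.randn(1, 77, 32)   # (batch, seq_len, joint_attention_dim)
timestep = torch.tensor([500])

with torch.no_grad():
    sample = model(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        timestep=timestep,
    ).sample

print(sample.shape)  # torch.Size([1, 4, 32, 32])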