diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +74 -28
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/transformer_allegro.py
@@ -0,0 +1,422 @@
+# Copyright 2024 The RhymesAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import is_torch_version, logging
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward
+from ..attention_processor import AllegroAttnProcessor2_0, Attention
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle
+
+
+logger = logging.get_logger(__name__)
+
+
+@maybe_allow_in_graph
+class AllegroTransformerBlock(nn.Module):
+    r"""
+    Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model.
+
+    Args:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`):
+            The number of channels in each head.
+        dropout (`float`, defaults to `0.0`):
+            The dropout probability to use.
+        cross_attention_dim (`int`, defaults to `2304`):
+            The dimension of the cross attention features.
+        activation_fn (`str`, defaults to `"gelu-approximate"`):
+            Activation function to be used in feed-forward.
+        attention_bias (`bool`, defaults to `False`):
+            Whether or not to use bias in attention projection layers.
+        only_cross_attention (`bool`, defaults to `False`):
+        norm_elementwise_affine (`bool`, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_eps (`float`, defaults to `1e-5`):
+            Epsilon value for normalization layers.
+        final_dropout (`bool` defaults to `False`):
+            Whether to apply a final dropout after the last feed-forward layer.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout=0.0,
+        cross_attention_dim: Optional[int] = None,
+        activation_fn: str = "geglu",
+        attention_bias: bool = False,
+        norm_elementwise_affine: bool = True,
+        norm_eps: float = 1e-5,
+    ):
+        super().__init__()
+
+        # 1. Self Attention
+        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=None,
+            processor=AllegroAttnProcessor2_0(),
+        )
+
+        # 2. Cross Attention
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+        self.attn2 = Attention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            processor=AllegroAttnProcessor2_0(),
+        )
+
+        # 3. Feed Forward
+        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+        )
+
+        # 4. Scale-shift
+        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        temb: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb=None,
+    ) -> torch.Tensor:
+        # 0. Self-Attention
+        batch_size = hidden_states.shape[0]
+
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+            self.scale_shift_table[None] + temb.reshape(batch_size, 6, -1)
+        ).chunk(6, dim=1)
+        norm_hidden_states = self.norm1(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+        norm_hidden_states = norm_hidden_states.squeeze(1)
+
+        attn_output = self.attn1(
+            norm_hidden_states,
+            encoder_hidden_states=None,
+            attention_mask=attention_mask,
+            image_rotary_emb=image_rotary_emb,
+        )
+        attn_output = gate_msa * attn_output
+
+        hidden_states = attn_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        # 1. Cross-Attention
+        if self.attn2 is not None:
+            norm_hidden_states = hidden_states
+
+            attn_output = self.attn2(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                image_rotary_emb=None,
+            )
+            hidden_states = attn_output + hidden_states
+
+        # 2. Feed-forward
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+        ff_output = self.ff(norm_hidden_states)
+        ff_output = gate_mlp * ff_output
+
+        hidden_states = ff_output + hidden_states
+
+        # TODO(aryan): maybe following line is not required
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        return hidden_states
+
+
+class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
+    _supports_gradient_checkpointing = True
+
+    """
+    A 3D Transformer model for video-like data.
+
+    Args:
+        patch_size (`int`, defaults to `2`):
+            The size of spatial patches to use in the patch embedding layer.
+        patch_size_t (`int`, defaults to `1`):
+            The size of temporal patches to use in the patch embedding layer.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `96`):
+            The number of channels in each head.
+        in_channels (`int`, defaults to `4`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `4`):
+            The number of channels in the output.
+        num_layers (`int`, defaults to `32`):
+            The number of layers of Transformer blocks to use.
+        dropout (`float`, defaults to `0.0`):
+            The dropout probability to use.
+        cross_attention_dim (`int`, defaults to `2304`):
+            The dimension of the cross attention features.
+        attention_bias (`bool`, defaults to `True`):
+            Whether or not to use bias in the attention projection layers.
+        sample_height (`int`, defaults to `90`):
+            The height of the input latents.
+        sample_width (`int`, defaults to `160`):
+            The width of the input latents.
+        sample_frames (`int`, defaults to `22`):
+            The number of frames in the input latents.
+        activation_fn (`str`, defaults to `"gelu-approximate"`):
+            Activation function to use in feed-forward.
+        norm_elementwise_affine (`bool`, defaults to `False`):
+            Whether or not to use elementwise affine in normalization layers.
+        norm_eps (`float`, defaults to `1e-6`):
+            The epsilon value to use in normalization layers.
+        caption_channels (`int`, defaults to `4096`):
+            Number of channels to use for projecting the caption embeddings.
+        interpolation_scale_h (`float`, defaults to `2.0`):
+            Scaling factor to apply in 3D positional embeddings across height dimension.
+        interpolation_scale_w (`float`, defaults to `2.0`):
+            Scaling factor to apply in 3D positional embeddings across width dimension.
+        interpolation_scale_t (`float`, defaults to `2.2`):
+            Scaling factor to apply in 3D positional embeddings across time dimension.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        patch_size: int = 2,
+        patch_size_t: int = 1,
+        num_attention_heads: int = 24,
+        attention_head_dim: int = 96,
+        in_channels: int = 4,
+        out_channels: int = 4,
+        num_layers: int = 32,
+        dropout: float = 0.0,
+        cross_attention_dim: int = 2304,
+        attention_bias: bool = True,
+        sample_height: int = 90,
+        sample_width: int = 160,
+        sample_frames: int = 22,
+        activation_fn: str = "gelu-approximate",
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-6,
+        caption_channels: int = 4096,
+        interpolation_scale_h: float = 2.0,
+        interpolation_scale_w: float = 2.0,
+        interpolation_scale_t: float = 2.2,
+    ):
+        super().__init__()
+
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        interpolation_scale_t = (
+            interpolation_scale_t
+            if interpolation_scale_t is not None
+            else ((sample_frames - 1) // 16 + 1)
+            if sample_frames % 2 == 1
+            else sample_frames // 16
+        )
+        interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30
+        interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40
+
+        # 1. Patch embedding
+        self.pos_embed = PatchEmbed(
+            height=sample_height,
+            width=sample_width,
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=self.inner_dim,
+            pos_embed_type=None,
+        )
+
+        # 2. Transformer blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                AllegroTransformerBlock(
+                    self.inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    dropout=dropout,
+                    cross_attention_dim=cross_attention_dim,
+                    activation_fn=activation_fn,
+                    attention_bias=attention_bias,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                    norm_eps=norm_eps,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # 3. Output projection & norm
+        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)
+
+        # 4. Timestep embeddings
+        self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)
+
+        # 5. Caption projection
+        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim)
+
+        self.gradient_checkpointing = False
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        self.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        return_dict: bool = True,
+    ):
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p_t = self.config.patch_size_t
+        p = self.config.patch_size
+
+        post_patch_num_frames = num_frames // p_t
+        post_patch_height = height // p
+        post_patch_width = width // p
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) attention_mask_vid, attention_mask_img = None, None
+        if attention_mask is not None and attention_mask.ndim == 4:
+            # assume that mask is expressed as:
+            #   (1 = keep, 0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0, discard = -10000.0)
+            # b, frame+use_image_num, h, w -> a video with images
+            # b, 1, h, w -> only images
+            attention_mask = attention_mask.to(hidden_states.dtype)
+            attention_mask = attention_mask[:, :num_frames]  # [batch_size, num_frames, height, width]
+
+            if attention_mask.numel() > 0:
+                attention_mask = attention_mask.unsqueeze(1)  # [batch_size, 1, num_frames, height, width]
+                attention_mask = F.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p))
+                attention_mask = attention_mask.flatten(1).view(batch_size, 1, -1)
+
+            attention_mask = (
+                (1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None
+            )
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 1. Timestep embeddings
+        timestep, embedded_timestep = self.adaln_single(
+            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+        )
+
+        # 2. Patch embeddings
+        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+        hidden_states = self.pos_embed(hidden_states)
+        hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
+
+        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])
+
+        # 3. Transformer blocks
+        for i, block in enumerate(self.transformer_blocks):
+            # TODO(aryan): Implement gradient checkpointing
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    timestep,
+                    attention_mask,
+                    encoder_attention_mask,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=timestep,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    image_rotary_emb=image_rotary_emb,
+                )
+
+        # 4. Output normalization & projection
+        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+        hidden_states = self.norm_out(hidden_states)
+
+        # Modulation
+        hidden_states = hidden_states * (1 + scale) + shift
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = hidden_states.squeeze(1)
+
+        # 5. Unpatchify
+        hidden_states = hidden_states.reshape(
+            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1
+        )
+        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+        output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
+
+        if not return_dict:
+            return (output,)

+        return Transformer2DModelOutput(sample=output)
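The unpatchify step at the end of AllegroTransformer3DModel.forward ("# 5. Unpatchify" above) folds the token sequence back into a (batch, channels, frames, height, width) latent. Below is a standalone shape check of that reshape/permute using plain PyTorch; the dimensions are illustrative only (the real defaults are sample_frames=22, sample_height=90, sample_width=160), and no diffusers import is needed.

    import torch

    # Illustrative sizes only, chosen to be divisible by the patch sizes.
    batch_size, channels, num_frames, height, width = 1, 4, 8, 16, 16
    p_t, p = 1, 2  # patch_size_t, patch_size

    post_patch_num_frames = num_frames // p_t
    post_patch_height = height // p
    post_patch_width = width // p
    seq_len = post_patch_num_frames * post_patch_height * post_patch_width

    # Tokens entering the unpatchify carry p_t * p * p * channels values each.
    tokens = torch.randn(batch_size, seq_len, p_t * p * p * channels)

    # Same reshape/permute as in the new file.
    x = tokens.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1)
    x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
    output = x.reshape(batch_size, -1, num_frames, height, width)
    assert output.shape == (batch_size, channels, num_frames, height, width)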
diffusers/models/transformers/transformer_cogview3plus.py
@@ -341,7 +341,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
         hidden_states = hidden_states[:, text_seq_length:]

         for index_block, block in enumerate(self.transformer_blocks):
-            if
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diffusers/models/transformers/transformer_flux.py
@@ -21,17 +21,19 @@ import torch.nn as nn
 import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
 from ...models.attention import FeedForward
 from ...models.attention_processor import (
     Attention,
     AttentionProcessor,
     FluxAttnProcessor2_0,
+    FluxAttnProcessor2_0_NPU,
     FusedFluxAttnProcessor2_0,
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -64,7 +66,10 @@ class FluxSingleTransformerBlock(nn.Module):
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

-
+        if is_torch_npu_available():
+            processor = FluxAttnProcessor2_0_NPU()
+        else:
+            processor = FluxAttnProcessor2_0()
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -172,13 +177,18 @@ class FluxTransformerBlock(nn.Module):
         )
         joint_attention_kwargs = joint_attention_kwargs or {}
         # Attention.
-
+        attention_outputs = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
             **joint_attention_kwargs,
         )

+        if len(attention_outputs) == 2:
+            attn_output, context_attn_output = attention_outputs
+        elif len(attention_outputs) == 3:
+            attn_output, context_attn_output, ip_attn_output = attention_outputs
+
         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
@@ -190,6 +200,8 @@ class FluxTransformerBlock(nn.Module):
         ff_output = gate_mlp.unsqueeze(1) * ff_output

         hidden_states = hidden_states + ff_output
+        if len(attention_outputs) == 3:
+            hidden_states = hidden_states + ip_attn_output

         # Process attention outputs for the `encoder_hidden_states`.

@@ -207,7 +219,9 @@ class FluxTransformerBlock(nn.Module):
         return encoder_hidden_states, hidden_states


-class FluxTransformer2DModel(
+class FluxTransformer2DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
+):
     """
     The Transformer model introduced in Flux.

@@ -233,6 +247,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         self,
         patch_size: int = 1,
         in_channels: int = 64,
+        out_channels: Optional[int] = None,
         num_layers: int = 19,
         num_single_layers: int = 38,
         attention_head_dim: int = 128,
@@ -243,7 +258,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         axes_dims_rope: Tuple[int] = (16, 56, 56),
     ):
         super().__init__()
-        self.out_channels = in_channels
+        self.out_channels = out_channels or in_channels
         self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
@@ -256,7 +271,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         )

         self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
-        self.x_embedder =
+        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)

         self.transformer_blocks = nn.ModuleList(
             [
@@ -444,6 +459,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
             logger.warning(
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )
+
         hidden_states = self.x_embedder(hidden_states)

         timestep = timestep.to(hidden_states.dtype) * 1000
@@ -451,6 +467,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
             guidance = guidance.to(hidden_states.dtype) * 1000
         else:
             guidance = None
+
         temb = (
             self.time_text_embed(timestep, pooled_projections)
             if guidance is None
@@ -474,8 +491,13 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         ids = torch.cat((txt_ids, img_ids), dim=0)
         image_rotary_emb = self.pos_embed(ids)

+        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
         for index_block, block in enumerate(self.transformer_blocks):
-            if
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -516,11 +538,10 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
                     )
                 else:
                     hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
-
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

         for index_block, block in enumerate(self.single_transformer_blocks):
-            if
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):