diffusers 0.28.0__py3-none-any.whl → 0.28.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registry.
- diffusers/__init__.py +9 -1
- diffusers/configuration_utils.py +17 -0
- diffusers/models/__init__.py +6 -0
- diffusers/models/activations.py +12 -0
- diffusers/models/attention_processor.py +108 -0
- diffusers/models/embeddings.py +216 -8
- diffusers/models/model_loading_utils.py +28 -0
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_utils.py +57 -1
- diffusers/models/normalization.py +2 -1
- diffusers/models/transformers/__init__.py +3 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
- diffusers/models/transformers/pixart_transformer_2d.py +336 -0
- diffusers/models/transformers/transformer_2d.py +37 -45
- diffusers/pipelines/__init__.py +2 -0
- diffusers/pipelines/dit/pipeline_dit.py +4 -4
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
- diffusers/pipelines/pipeline_loading_utils.py +1 -0
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +4 -4
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +2 -2
- diffusers/utils/dummy_pt_objects.py +45 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +15 -0
- {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/METADATA +44 -44
- {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/RECORD +30 -25
- {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/WHEEL +1 -1
- {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
- {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
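Taken together, the listing adds dedicated transformer modules (DiT, PixArt, HunyuanDiT) and a new HunyuanDiT pipeline, and wires them into the package `__init__` files and the dummy-object tables. A minimal, hypothetical smoke check, assuming the new classes are re-exported at the top level under the names suggested by the file names (the diff bodies below confirm `DiTTransformer2DModel` and `HunyuanDiT2DModel`; the other names are inferred):

```python
# Hypothetical import check for the symbols this release appears to add.
import diffusers

for name in (
    "DiTTransformer2DModel",
    "PixArtTransformer2DModel",
    "HunyuanDiT2DModel",
    "HunyuanDiTPipeline",
):
    print(f"{name}: {hasattr(diffusers, name)}")
```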
diffusers/models/transformers/dit_transformer_2d.py
@@ -0,0 +1,240 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ..attention import BasicTransformerBlock
from ..embeddings import PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class DiTTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748).

    Parameters:
        num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
        in_channels (int, defaults to 4): The number of channels in the input.
        out_channels (int, optional):
            The number of channels in the output. Specify this parameter if the output channel number differs from the
            input.
        num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
        dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
        norm_num_groups (int, optional, defaults to 32):
            Number of groups for group normalization within Transformer blocks.
        attention_bias (bool, optional, defaults to True):
            Configure if the Transformer blocks' attention should contain a bias parameter.
        sample_size (int, defaults to 32):
            The width of the latent images. This parameter is fixed during training.
        patch_size (int, defaults to 2):
            Size of the patches the model processes, relevant for architectures working on non-sequential data.
        activation_fn (str, optional, defaults to "gelu-approximate"):
            Activation function to use in feed-forward networks within Transformer blocks.
        num_embeds_ada_norm (int, optional, defaults to 1000):
            Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
            inference.
        upcast_attention (bool, optional, defaults to False):
            If true, upcasts the attention mechanism dimensions for potentially improved performance.
        norm_type (str, optional, defaults to "ada_norm_zero"):
            Specifies the type of normalization used, can be 'ada_norm_zero'.
        norm_elementwise_affine (bool, optional, defaults to False):
            If true, enables element-wise affine parameters in the normalization layers.
        norm_eps (float, optional, defaults to 1e-5):
            A small constant added to the denominator in normalization layers to prevent division by zero.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 72,
        in_channels: int = 4,
        out_channels: Optional[int] = None,
        num_layers: int = 28,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        attention_bias: bool = True,
        sample_size: int = 32,
        patch_size: int = 2,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: Optional[int] = 1000,
        upcast_attention: bool = False,
        norm_type: str = "ada_norm_zero",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
    ):
        super().__init__()

        # Validate inputs.
        if norm_type != "ada_norm_zero":
            raise NotImplementedError(
                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
            )
        elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None:
            raise ValueError(
                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
            )

        # Set some common variables used across the board.
        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.out_channels = in_channels if out_channels is None else out_channels
        self.gradient_checkpointing = False

        # 2. Initialize the position embedding and transformer blocks.
        self.height = self.config.sample_size
        self.width = self.config.sample_size

        self.patch_size = self.config.patch_size
        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    activation_fn=self.config.activation_fn,
                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=self.config.norm_elementwise_affine,
                    norm_eps=self.config.norm_eps,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        # 3. Output blocks.
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
        self.proj_out_2 = nn.Linear(
            self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
        )

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        return_dict: bool = True,
    ):
        """
        The [`DiTTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # 1. Input
        height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
        hidden_states = self.pos_embed(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    None,
                    None,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=None,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

        # 3. Output
        conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype)
        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        hidden_states = self.proj_out_2(hidden_states)

        # unpatchify
        height = width = int(hidden_states.shape[1] ** 0.5)
        hidden_states = hidden_states.reshape(
            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
        )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
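The `DiTTransformer2DModel` above is a class-conditional latent transformer whose `ada_norm_zero` conditioning path needs latents, a timestep, and a class label at every forward pass. A minimal sketch with the default configuration; the reduced `num_layers` and the dummy tensors are illustrative choices, not part of this diff:

```python
import torch
from diffusers.models.transformers.dit_transformer_2d import DiTTransformer2DModel

# Small, randomly initialized instance; num_layers is lowered only to keep the example light.
model = DiTTransformer2DModel(num_layers=2)
model.eval()

latents = torch.randn(1, 4, 32, 32)   # (batch, in_channels, sample_size, sample_size)
timestep = torch.tensor([999])        # denoising step, embedded via AdaLayerNormZero
class_labels = torch.tensor([207])    # class id < num_embeds_ada_norm (default 1000)

with torch.no_grad():
    sample = model(latents, timestep=timestep, class_labels=class_labels).sample

print(sample.shape)  # torch.Size([1, 4, 32, 32]) with the default out_channels (= in_channels)
```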
diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -0,0 +1,427 @@
# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention, HunyuanAttnProcessor2_0
from ..embeddings import (
    HunyuanCombinedTimestepTextSizeStyleEmbedding,
    PatchEmbed,
    PixArtAlphaTextProjection,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class FP32LayerNorm(nn.LayerNorm):
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        return F.layer_norm(
            inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps
        ).to(origin_dtype)


class AdaLayerNormShift(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
    """

    def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim)
        self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        shift = self.linear(self.silu(emb.to(torch.float32)).to(emb.dtype))
        x = self.norm(x) + shift.unsqueeze(dim=1)
        return x


@maybe_allow_in_graph
class HunyuanDiTBlock(nn.Module):
    r"""
    Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and
    QKNorm

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of headsto use for multi-head attention.
        cross_attention_dim (`int`,*optional*):
            The size of the encoder_hidden_states vector for cross attention.
        dropout(`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        activation_fn (`str`,*optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward. .
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            A small constant added to the denominator in normalization layers to prevent division by zero.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*):
            The size of the hidden layer in the feed-forward block. Defaults to `None`.
        ff_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the feed-forward block.
        skip (`bool`, *optional*, defaults to `False`):
            Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
        qk_norm (`bool`, *optional*, defaults to `True`):
            Whether to use normalization in QK calculation. Defaults to `True`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        cross_attention_dim: int = 1024,
        dropout=0.0,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-6,
        final_dropout: bool = False,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        skip: bool = False,
        qk_norm: bool = True,
    ):
        super().__init__()

        # Define 3 blocks. Each block has its own normalization layer.
        # NOTE: when new version comes, check norm2 and norm 3
        # 1. Self-Attn
        self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=True,
            processor=HunyuanAttnProcessor2_0(),
        )

        # 2. Cross-Attn
        self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=True,
            processor=HunyuanAttnProcessor2_0(),
        )
        # 3. Feed-forward
        self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.ff = FeedForward(
            dim,
            dropout=dropout,  ### 0.0
            activation_fn=activation_fn,  ### approx GeLU
            final_dropout=final_dropout,  ### 0.0
            inner_dim=ff_inner_dim,  ### int(dim * mlp_ratio)
            bias=ff_bias,
        )

        # 4. Skip Connection
        if skip:
            self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True)
            self.skip_linear = nn.Linear(2 * dim, dim)
        else:
            self.skip_linear = None

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb=None,
        skip=None,
    ) -> torch.Tensor:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Long Skip Connection
        if self.skip_linear is not None:
            cat = torch.cat([hidden_states, skip], dim=-1)
            cat = self.skip_norm(cat)
            hidden_states = self.skip_linear(cat)

        # 1. Self-Attention
        norm_hidden_states = self.norm1(hidden_states, temb)  ### checked: self.norm1 is correct
        attn_output = self.attn1(
            norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attn_output

        # 2. Cross-Attention
        hidden_states = hidden_states + self.attn2(
            self.norm2(hidden_states),
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
        mlp_inputs = self.norm3(hidden_states)
        hidden_states = hidden_states + self.ff(mlp_inputs)

        return hidden_states


class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
    """
    HunYuanDiT: Diffusion model with a Transformer backbone.

    Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88):
            The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        patch_size (`int`, *optional*):
            The size of the patch to use for the input.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to use in feed-forward.
        sample_size (`int`, *optional*):
            The width of the latent images. This is fixed during training since it is used to learn a number of
            position embeddings.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        cross_attention_dim (`int`, *optional*):
            The number of dimension in the clip text embedding.
        hidden_size (`int`, *optional*):
            The size of hidden layer in the conditioning embedding layers.
        num_layers (`int`, *optional*, defaults to 1):
            The number of layers of Transformer blocks to use.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            The ratio of the hidden layer size to the input size.
        learn_sigma (`bool`, *optional*, defaults to `True`):
            Whether to predict variance.
        cross_attention_dim_t5 (`int`, *optional*):
            The number dimensions in t5 text embedding.
        pooled_projection_dim (`int`, *optional*):
            The size of the pooled projection.
        text_len (`int`, *optional*):
            The length of the clip text embedding.
        text_len_t5 (`int`, *optional*):
            The length of the T5 text embedding.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "gelu-approximate",
        sample_size=32,
        hidden_size=1152,
        num_layers: int = 28,
        mlp_ratio: float = 4.0,
        learn_sigma: bool = True,
        cross_attention_dim: int = 1024,
        norm_type: str = "layer_norm",
        cross_attention_dim_t5: int = 2048,
        pooled_projection_dim: int = 1024,
        text_len: int = 77,
        text_len_t5: int = 256,
    ):
        super().__init__()
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.num_heads = num_attention_heads
        self.inner_dim = num_attention_heads * attention_head_dim

        self.text_embedder = PixArtAlphaTextProjection(
            in_features=cross_attention_dim_t5,
            hidden_size=cross_attention_dim_t5 * 4,
            out_features=cross_attention_dim,
            act_fn="silu_fp32",
        )

        self.text_embedding_padding = nn.Parameter(
            torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
        )

        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
            patch_size=patch_size,
            pos_embed_type=None,
        )

        self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
            hidden_size,
            pooled_projection_dim=pooled_projection_dim,
            seq_len=text_len_t5,
            cross_attention_dim=cross_attention_dim_t5,
        )

        # HunyuanDiT Blocks
        self.blocks = nn.ModuleList(
            [
                HunyuanDiTBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    activation_fn=activation_fn,
                    ff_inner_dim=int(self.inner_dim * mlp_ratio),
                    cross_attention_dim=cross_attention_dim,
                    qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
                    skip=layer > num_layers // 2,
                )
                for layer in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

    def forward(
        self,
        hidden_states,
        timestep,
        encoder_hidden_states=None,
        text_embedding_mask=None,
        encoder_hidden_states_t5=None,
        text_embedding_mask_t5=None,
        image_meta_size=None,
        style=None,
        image_rotary_emb=None,
        return_dict=True,
    ):
        """
        The [`HunyuanDiT2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
                The input tensor.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step.
            encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. This is the output of `BertModel`.
            text_embedding_mask: torch.Tensor
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
                of `BertModel`.
            encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
            text_embedding_mask_t5: torch.Tensor
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
                of T5 Text Encoder.
            image_meta_size (torch.Tensor):
                Conditional embedding indicate the image sizes
            style: torch.Tensor:
                Conditional embedding indicate the style
            image_rotary_emb (`torch.Tensor`):
                The image rotary embeddings to apply on query and key tensors during attention calculation.
            return_dict: bool
                Whether to return a dictionary.
        """

        height, width = hidden_states.shape[-2:]

        hidden_states = self.pos_embed(hidden_states)

        temb = self.time_extra_emb(
            timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
        )  # [B, D]

        # text projection
        batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
        encoder_hidden_states_t5 = self.text_embedder(
            encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
        )
        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)

        encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
        text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
        text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()

        encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)

        skips = []
        for layer, block in enumerate(self.blocks):
            if layer > self.config.num_layers // 2:
                skip = skips.pop()
                hidden_states = block(
                    hidden_states,
                    temb=temb,
                    encoder_hidden_states=encoder_hidden_states,
                    image_rotary_emb=image_rotary_emb,
                    skip=skip,
                )  # (N, L, D)
            else:
                hidden_states = block(
                    hidden_states,
                    temb=temb,
                    encoder_hidden_states=encoder_hidden_states,
                    image_rotary_emb=image_rotary_emb,
                )  # (N, L, D)

            if layer < (self.config.num_layers // 2 - 1):
                skips.append(hidden_states)

        # final layer
        hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
        hidden_states = self.proj_out(hidden_states)
        # (N, L, patch_size ** 2 * out_channels)

        # unpatchify: (N, out_channels, H, W)
        patch_size = self.pos_embed.patch_size
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
        )
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
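`HunyuanDiT2DModel` is consumed by the new `diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py` listed above, which prepares the CLIP/T5 text embeddings, masks, image-meta-size/style conditioning, and rotary embeddings that the model's `forward` expects. A minimal end-to-end sketch, assuming the pipeline is exposed as `HunyuanDiTPipeline` and that converted weights exist under the checkpoint id below (both assumptions; neither is shown in this diff):

```python
import torch
from diffusers import HunyuanDiTPipeline  # assumed top-level export added by this release

# Placeholder checkpoint id; substitute whichever HunyuanDiT weights you actually use.
pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

image = pipe(prompt="a watercolor painting of a fox in a bamboo forest").images[0]
image.save("hunyuan_dit_sample.png")
```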