diffusers 0.28.0__py3-none-any.whl → 0.28.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. diffusers/__init__.py +9 -1
  2. diffusers/configuration_utils.py +17 -0
  3. diffusers/loaders/single_file_utils.py +1 -1
  4. diffusers/models/__init__.py +6 -0
  5. diffusers/models/activations.py +12 -0
  6. diffusers/models/attention_processor.py +108 -0
  7. diffusers/models/embeddings.py +216 -8
  8. diffusers/models/model_loading_utils.py +28 -0
  9. diffusers/models/modeling_outputs.py +14 -0
  10. diffusers/models/modeling_utils.py +57 -1
  11. diffusers/models/normalization.py +2 -1
  12. diffusers/models/transformers/__init__.py +3 -0
  13. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  14. diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
  15. diffusers/models/transformers/pixart_transformer_2d.py +336 -0
  16. diffusers/models/transformers/transformer_2d.py +37 -45
  17. diffusers/pipelines/__init__.py +2 -0
  18. diffusers/pipelines/dit/pipeline_dit.py +4 -4
  19. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  20. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
  21. diffusers/pipelines/pipeline_loading_utils.py +1 -0
  22. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +4 -4
  23. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +2 -2
  24. diffusers/utils/dummy_pt_objects.py +45 -0
  25. diffusers/utils/dummy_torch_and_transformers_objects.py +15 -0
  26. {diffusers-0.28.0.dist-info → diffusers-0.28.2.dist-info}/METADATA +44 -44
  27. {diffusers-0.28.0.dist-info → diffusers-0.28.2.dist-info}/RECORD +31 -26
  28. {diffusers-0.28.0.dist-info → diffusers-0.28.2.dist-info}/WHEEL +1 -1
  29. {diffusers-0.28.0.dist-info → diffusers-0.28.2.dist-info}/LICENSE +0 -0
  30. {diffusers-0.28.0.dist-info → diffusers-0.28.2.dist-info}/entry_points.txt +0 -0
  31. {diffusers-0.28.0.dist-info → diffusers-0.28.2.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/__init__.py
@@ -2,7 +2,10 @@ from ...utils import is_torch_available
 
 
 if is_torch_available():
+    from .dit_transformer_2d import DiTTransformer2DModel
     from .dual_transformer_2d import DualTransformer2DModel
+    from .hunyuan_transformer_2d import HunyuanDiT2DModel
+    from .pixart_transformer_2d import PixArtTransformer2DModel
     from .prior_transformer import PriorTransformer
     from .t5_film_transformer import T5FilmDecoder
     from .transformer_2d import Transformer2DModel
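The hunk above registers the three new transformer classes under `diffusers.models.transformers`; together with the `diffusers/__init__.py` and `diffusers/models/__init__.py` changes listed in the file table, they should also be reachable from the package root. A minimal sketch, assuming those top-level re-exports:

```python
# Assumes the top-level re-exports implied by the diffusers/__init__.py change in this release.
from diffusers import DiTTransformer2DModel, HunyuanDiT2DModel, PixArtTransformer2DModel

# The classes live in the submodules added by this diff:
print(DiTTransformer2DModel.__module__)  # diffusers.models.transformers.dit_transformer_2d
```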
diffusers/models/transformers/dit_transformer_2d.py
@@ -0,0 +1,240 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import is_torch_version, logging
+from ..attention import BasicTransformerBlock
+from ..embeddings import PatchEmbed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class DiTTransformer2DModel(ModelMixin, ConfigMixin):
+    r"""
+    A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748).
+
+    Parameters:
+        num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
+        in_channels (int, defaults to 4): The number of channels in the input.
+        out_channels (int, optional):
+            The number of channels in the output. Specify this parameter if the output channel number differs from the
+            input.
+        num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
+        dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
+        norm_num_groups (int, optional, defaults to 32):
+            Number of groups for group normalization within Transformer blocks.
+        attention_bias (bool, optional, defaults to True):
+            Configure if the Transformer blocks' attention should contain a bias parameter.
+        sample_size (int, defaults to 32):
+            The width of the latent images. This parameter is fixed during training.
+        patch_size (int, defaults to 2):
+            Size of the patches the model processes, relevant for architectures working on non-sequential data.
+        activation_fn (str, optional, defaults to "gelu-approximate"):
+            Activation function to use in feed-forward networks within Transformer blocks.
+        num_embeds_ada_norm (int, optional, defaults to 1000):
+            Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
+            inference.
+        upcast_attention (bool, optional, defaults to False):
+            If true, upcasts the attention mechanism dimensions for potentially improved performance.
+        norm_type (str, optional, defaults to "ada_norm_zero"):
+            Specifies the type of normalization used, can be 'ada_norm_zero'.
+        norm_elementwise_affine (bool, optional, defaults to False):
+            If true, enables element-wise affine parameters in the normalization layers.
+        norm_eps (float, optional, defaults to 1e-5):
+            A small constant added to the denominator in normalization layers to prevent division by zero.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 72,
+        in_channels: int = 4,
+        out_channels: Optional[int] = None,
+        num_layers: int = 28,
+        dropout: float = 0.0,
+        norm_num_groups: int = 32,
+        attention_bias: bool = True,
+        sample_size: int = 32,
+        patch_size: int = 2,
+        activation_fn: str = "gelu-approximate",
+        num_embeds_ada_norm: Optional[int] = 1000,
+        upcast_attention: bool = False,
+        norm_type: str = "ada_norm_zero",
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-5,
+    ):
+        super().__init__()
+
+        # Validate inputs.
+        if norm_type != "ada_norm_zero":
+            raise NotImplementedError(
+                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
+            )
+        elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None:
+            raise ValueError(
+                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
+            )
+
+        # Set some common variables used across the board.
+        self.attention_head_dim = attention_head_dim
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.gradient_checkpointing = False
+
+        # 2. Initialize the position embedding and transformer blocks.
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+
+        self.patch_size = self.config.patch_size
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.config.in_channels,
+            embed_dim=self.inner_dim,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
+
+        # 3. Output blocks.
+        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
+        self.proj_out_2 = nn.Linear(
+            self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+        )
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: Optional[torch.LongTensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        return_dict: bool = True,
+    ):
+        """
+        The [`DiTTransformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+                Input `hidden_states`.
+            timestep ( `torch.LongTensor`, *optional*):
+                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+                `AdaLayerZeroNorm`.
+            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+        # 1. Input
+        height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
+        hidden_states = self.pos_embed(hidden_states)
+
+        # 2. Blocks
+        for block in self.transformer_blocks:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    None,
+                    None,
+                    None,
+                    timestep,
+                    cross_attention_kwargs,
+                    class_labels,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask=None,
+                    encoder_hidden_states=None,
+                    encoder_attention_mask=None,
+                    timestep=timestep,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    class_labels=class_labels,
+                )
+
+        # 3. Output
+        conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype)
+        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+        hidden_states = self.proj_out_2(hidden_states)
+
+        # unpatchify
+        height = width = int(hidden_states.shape[1] ** 0.5)
+        hidden_states = hidden_states.reshape(
+            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+        )
+        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+        )
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
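For orientation, a minimal forward-pass sketch of the new `DiTTransformer2DModel`, using a deliberately tiny hypothetical configuration (the shipped defaults are 28 layers with an inner dimension of 1152):

```python
import torch
from diffusers.models.transformers import DiTTransformer2DModel

# Tiny config for illustration only; real DiT checkpoints use much larger values.
model = DiTTransformer2DModel(
    num_attention_heads=4,
    attention_head_dim=8,        # inner_dim = 4 * 8 = 32
    in_channels=4,
    num_layers=2,
    sample_size=8,               # latent height/width
    patch_size=2,
    num_embeds_ada_norm=10,      # number of class labels for the AdaLayerNormZero embedding
)

latents = torch.randn(1, 4, 8, 8)                   # (batch, in_channels, sample_size, sample_size)
timestep = torch.tensor([10], dtype=torch.long)     # one denoising step per sample
class_labels = torch.tensor([3], dtype=torch.long)  # class id < num_embeds_ada_norm

out = model(latents, timestep=timestep, class_labels=class_labels).sample
print(out.shape)  # torch.Size([1, 4, 8, 8])
```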
diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -0,0 +1,427 @@
+# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward
+from ..attention_processor import Attention, HunyuanAttnProcessor2_0
+from ..embeddings import (
+    HunyuanCombinedTimestepTextSizeStyleEmbedding,
+    PatchEmbed,
+    PixArtAlphaTextProjection,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class FP32LayerNorm(nn.LayerNorm):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        origin_dtype = inputs.dtype
+        return F.layer_norm(
+            inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps
+        ).to(origin_dtype)
+
+
+class AdaLayerNormShift(nn.Module):
+    r"""
+    Norm layer modified to incorporate timestep embeddings.
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6):
+        super().__init__()
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, embedding_dim)
+        self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
+
+    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+        shift = self.linear(self.silu(emb.to(torch.float32)).to(emb.dtype))
+        x = self.norm(x) + shift.unsqueeze(dim=1)
+        return x
+
+
+@maybe_allow_in_graph
+class HunyuanDiTBlock(nn.Module):
+    r"""
+    Transformer block used in the Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allows skip connections
+    and QK normalization.
+
+    Parameters:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        cross_attention_dim (`int`, *optional*):
+            The size of the encoder_hidden_states vector for cross attention.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`):
+            Activation function to be used in the feed-forward block.
+        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_eps (`float`, *optional*, defaults to 1e-6):
+            A small constant added to the denominator in normalization layers to prevent division by zero.
+        final_dropout (`bool`, *optional*, defaults to `False`):
+            Whether to apply a final dropout after the last feed-forward layer.
+        ff_inner_dim (`int`, *optional*):
+            The size of the hidden layer in the feed-forward block. Defaults to `None`.
+        ff_bias (`bool`, *optional*, defaults to `True`):
+            Whether to use bias in the feed-forward block.
+        skip (`bool`, *optional*, defaults to `False`):
+            Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
+        qk_norm (`bool`, *optional*, defaults to `True`):
+            Whether to use normalization in QK calculation. Defaults to `True`.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        cross_attention_dim: int = 1024,
+        dropout=0.0,
+        activation_fn: str = "geglu",
+        norm_elementwise_affine: bool = True,
+        norm_eps: float = 1e-6,
+        final_dropout: bool = False,
+        ff_inner_dim: Optional[int] = None,
+        ff_bias: bool = True,
+        skip: bool = False,
+        qk_norm: bool = True,
+    ):
+        super().__init__()
+
+        # Define 3 blocks. Each block has its own normalization layer.
+        # NOTE: when new version comes, check norm2 and norm 3
+        # 1. Self-Attn
+        self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            dim_head=dim // num_attention_heads,
+            heads=num_attention_heads,
+            qk_norm="layer_norm" if qk_norm else None,
+            eps=1e-6,
+            bias=True,
+            processor=HunyuanAttnProcessor2_0(),
+        )
+
+        # 2. Cross-Attn
+        self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+        self.attn2 = Attention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            dim_head=dim // num_attention_heads,
+            heads=num_attention_heads,
+            qk_norm="layer_norm" if qk_norm else None,
+            eps=1e-6,
+            bias=True,
+            processor=HunyuanAttnProcessor2_0(),
+        )
+        # 3. Feed-forward
+        self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,  ### 0.0
+            activation_fn=activation_fn,  ### approx GeLU
+            final_dropout=final_dropout,  ### 0.0
+            inner_dim=ff_inner_dim,  ### int(dim * mlp_ratio)
+            bias=ff_bias,
+        )
+
+        # 4. Skip Connection
+        if skip:
+            self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True)
+            self.skip_linear = nn.Linear(2 * dim, dim)
+        else:
+            self.skip_linear = None
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb=None,
+        skip=None,
+    ) -> torch.Tensor:
+        # Notice that normalization is always applied before the real computation in the following blocks.
+        # 0. Long Skip Connection
+        if self.skip_linear is not None:
+            cat = torch.cat([hidden_states, skip], dim=-1)
+            cat = self.skip_norm(cat)
+            hidden_states = self.skip_linear(cat)
+
+        # 1. Self-Attention
+        norm_hidden_states = self.norm1(hidden_states, temb)  ### checked: self.norm1 is correct
+        attn_output = self.attn1(
+            norm_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+        )
+        hidden_states = hidden_states + attn_output
+
+        # 2. Cross-Attention
+        hidden_states = hidden_states + self.attn2(
+            self.norm2(hidden_states),
+            encoder_hidden_states=encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+        )
+
+        # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
+        mlp_inputs = self.norm3(hidden_states)
+        hidden_states = hidden_states + self.ff(mlp_inputs)
+
+        return hidden_states
+
+
+class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
+    """
+    HunYuanDiT: Diffusion model with a Transformer backbone.
+
+    Inherits ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 88):
+            The number of channels in each head.
+        in_channels (`int`, *optional*):
+            The number of channels in the input and output (specify if the input is **continuous**).
+        patch_size (`int`, *optional*):
+            The size of the patch to use for the input.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`):
+            Activation function to use in feed-forward.
+        sample_size (`int`, *optional*):
+            The width of the latent images. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        cross_attention_dim (`int`, *optional*):
+            The number of dimensions in the CLIP text embedding.
+        hidden_size (`int`, *optional*):
+            The size of the hidden layer in the conditioning embedding layers.
+        num_layers (`int`, *optional*, defaults to 1):
+            The number of layers of Transformer blocks to use.
+        mlp_ratio (`float`, *optional*, defaults to 4.0):
+            The ratio of the hidden layer size to the input size.
+        learn_sigma (`bool`, *optional*, defaults to `True`):
+            Whether to predict variance.
+        cross_attention_dim_t5 (`int`, *optional*):
+            The number of dimensions in the T5 text embedding.
+        pooled_projection_dim (`int`, *optional*):
+            The size of the pooled projection.
+        text_len (`int`, *optional*):
+            The length of the CLIP text embedding.
+        text_len_t5 (`int`, *optional*):
+            The length of the T5 text embedding.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: Optional[int] = None,
+        patch_size: Optional[int] = None,
+        activation_fn: str = "gelu-approximate",
+        sample_size=32,
+        hidden_size=1152,
+        num_layers: int = 28,
+        mlp_ratio: float = 4.0,
+        learn_sigma: bool = True,
+        cross_attention_dim: int = 1024,
+        norm_type: str = "layer_norm",
+        cross_attention_dim_t5: int = 2048,
+        pooled_projection_dim: int = 1024,
+        text_len: int = 77,
+        text_len_t5: int = 256,
+    ):
+        super().__init__()
+        self.out_channels = in_channels * 2 if learn_sigma else in_channels
+        self.num_heads = num_attention_heads
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        self.text_embedder = PixArtAlphaTextProjection(
+            in_features=cross_attention_dim_t5,
+            hidden_size=cross_attention_dim_t5 * 4,
+            out_features=cross_attention_dim,
+            act_fn="silu_fp32",
+        )
+
+        self.text_embedding_padding = nn.Parameter(
+            torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
+        )
+
+        self.pos_embed = PatchEmbed(
+            height=sample_size,
+            width=sample_size,
+            in_channels=in_channels,
+            embed_dim=hidden_size,
+            patch_size=patch_size,
+            pos_embed_type=None,
+        )
+
+        self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
+            hidden_size,
+            pooled_projection_dim=pooled_projection_dim,
+            seq_len=text_len_t5,
+            cross_attention_dim=cross_attention_dim_t5,
+        )
+
+        # HunyuanDiT Blocks
+        self.blocks = nn.ModuleList(
+            [
+                HunyuanDiTBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=self.config.num_attention_heads,
+                    activation_fn=activation_fn,
+                    ff_inner_dim=int(self.inner_dim * mlp_ratio),
+                    cross_attention_dim=cross_attention_dim,
+                    qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
+                    skip=layer > num_layers // 2,
+                )
+                for layer in range(num_layers)
+            ]
+        )
+
+        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+    def forward(
+        self,
+        hidden_states,
+        timestep,
+        encoder_hidden_states=None,
+        text_embedding_mask=None,
+        encoder_hidden_states_t5=None,
+        text_embedding_mask_t5=None,
+        image_meta_size=None,
+        style=None,
+        image_rotary_emb=None,
+        return_dict=True,
+    ):
+        """
+        The [`HunyuanDiT2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
+                The input tensor.
+            timestep ( `torch.LongTensor`, *optional*):
+                Used to indicate denoising step.
+            encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+                Conditional embeddings for cross attention layer. This is the output of `BertModel`.
+            text_embedding_mask: torch.Tensor
+                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the
+                output of `BertModel`.
+            encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+                Conditional embeddings for cross attention layer. This is the output of the T5 text encoder.
+            text_embedding_mask_t5: torch.Tensor
+                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the
+                output of the T5 text encoder.
+            image_meta_size (torch.Tensor):
+                Conditional embedding indicating the image sizes.
+            style: torch.Tensor:
+                Conditional embedding indicating the style.
+            image_rotary_emb (`torch.Tensor`):
+                The image rotary embeddings to apply on query and key tensors during attention calculation.
+            return_dict: bool
+                Whether to return a dictionary.
+        """
+
+        height, width = hidden_states.shape[-2:]
+
+        hidden_states = self.pos_embed(hidden_states)
+
+        temb = self.time_extra_emb(
+            timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
+        )  # [B, D]
+
+        # text projection
+        batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
+        encoder_hidden_states_t5 = self.text_embedder(
+            encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
+        )
+        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)
+
+        encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
+        text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
+        text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
+
+        encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
+
+        skips = []
+        for layer, block in enumerate(self.blocks):
+            if layer > self.config.num_layers // 2:
+                skip = skips.pop()
+                hidden_states = block(
+                    hidden_states,
+                    temb=temb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    image_rotary_emb=image_rotary_emb,
+                    skip=skip,
+                )  # (N, L, D)
+            else:
+                hidden_states = block(
+                    hidden_states,
+                    temb=temb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    image_rotary_emb=image_rotary_emb,
+                )  # (N, L, D)
+
+            if layer < (self.config.num_layers // 2 - 1):
+                skips.append(hidden_states)
+
+        # final layer
+        hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
+        hidden_states = self.proj_out(hidden_states)
+        # (N, L, patch_size ** 2 * out_channels)
+
+        # unpatchify: (N, out_channels, H, W)
+        patch_size = self.pos_embed.patch_size
+        height = height // patch_size
+        width = width // patch_size
+
+        hidden_states = hidden_states.reshape(
+            shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
+        )
+        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+        )
+        if not return_dict:
+            return (output,)
+        return Transformer2DModelOutput(sample=output)
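The matching pipeline ships in this release as `diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py` (file 20 above). A minimal usage sketch; the checkpoint id below is an assumption, so substitute whichever HunyuanDiT diffusers-format repository you actually want to load:

```python
import torch
from diffusers import HunyuanDiTPipeline

# "Tencent-Hunyuan/HunyuanDiT-Diffusers" is assumed here for illustration only.
pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

image = pipe(prompt="An astronaut riding a horse", num_inference_steps=25).images[0]
image.save("hunyuan_dit_sample.png")
```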