diffusers 0.28.0__py3-none-any.whl → 0.28.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the public registry.
- diffusers/__init__.py +9 -1
- diffusers/configuration_utils.py +17 -0
- diffusers/models/__init__.py +6 -0
- diffusers/models/activations.py +12 -0
- diffusers/models/attention_processor.py +108 -0
- diffusers/models/embeddings.py +216 -8
- diffusers/models/model_loading_utils.py +28 -0
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_utils.py +57 -1
- diffusers/models/normalization.py +2 -1
- diffusers/models/transformers/__init__.py +3 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
- diffusers/models/transformers/pixart_transformer_2d.py +336 -0
- diffusers/models/transformers/transformer_2d.py +37 -45
- diffusers/pipelines/__init__.py +2 -0
- diffusers/pipelines/dit/pipeline_dit.py +4 -4
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
- diffusers/pipelines/pipeline_loading_utils.py +1 -0
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +4 -4
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +2 -2
- diffusers/utils/dummy_pt_objects.py +45 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +15 -0
- {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/METADATA +44 -44
- {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/RECORD +30 -25
- {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/WHEEL +1 -1
- {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
- {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
diffusers/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.28.0"
+__version__ = "0.28.1"
 
 from typing import TYPE_CHECKING
 
@@ -82,11 +82,14 @@ else:
             "ConsistencyDecoderVAE",
             "ControlNetModel",
             "ControlNetXSAdapter",
+            "DiTTransformer2DModel",
+            "HunyuanDiT2DModel",
             "I2VGenXLUNet",
             "Kandinsky3UNet",
             "ModelMixin",
             "MotionAdapter",
             "MultiAdapter",
+            "PixArtTransformer2DModel",
             "PriorTransformer",
             "StableCascadeUNet",
             "T2IAdapter",
@@ -227,6 +230,7 @@ else:
             "BlipDiffusionPipeline",
             "CLIPImageProjection",
             "CycleDiffusionPipeline",
+            "HunyuanDiTPipeline",
             "I2VGenXLPipeline",
             "IFImg2ImgPipeline",
             "IFImg2ImgSuperResolutionPipeline",
@@ -484,11 +488,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         ConsistencyDecoderVAE,
         ControlNetModel,
         ControlNetXSAdapter,
+        DiTTransformer2DModel,
+        HunyuanDiT2DModel,
         I2VGenXLUNet,
         Kandinsky3UNet,
         ModelMixin,
         MotionAdapter,
         MultiAdapter,
+        PixArtTransformer2DModel,
         PriorTransformer,
         T2IAdapter,
         T5FilmDecoder,
@@ -607,6 +614,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         AudioLDMPipeline,
         CLIPImageProjection,
         CycleDiffusionPipeline,
+        HunyuanDiTPipeline,
         I2VGenXLPipeline,
         IFImg2ImgPipeline,
         IFImg2ImgSuperResolutionPipeline,
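Net effect of these hunks: three transformer models and one pipeline become part of the public, top-level API. A quick smoke test using only names visible above (should pass on a 0.28.1 install with torch and transformers available):

import diffusers
from diffusers import (
    DiTTransformer2DModel,
    HunyuanDiT2DModel,
    HunyuanDiTPipeline,
    PixArtTransformer2DModel,
)

print(diffusers.__version__)  # 0.28.1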
diffusers/configuration_utils.py
CHANGED
@@ -706,3 +706,20 @@ def flax_register_to_config(cls):
 
     cls.__init__ = init
     return cls
+
+
+class LegacyConfigMixin(ConfigMixin):
+    r"""
+    A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
+    pipeline-specific classes (like `DiTTransformer2DModel`).
+    """
+
+    @classmethod
+    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
+        # To prevent depedency import problem.
+        from .models.model_loading_utils import _fetch_remapped_cls_from_config
+
+        # resolve remapping
+        remapped_class = _fetch_remapped_cls_from_config(config, cls)
+
+        return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
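The practical consequence of LegacyConfigMixin is that instantiating a legacy class from a config transparently returns the newer class. A hedged sketch (the config below is a hypothetical minimum, not taken from this diff; the remapping keys come from _CLASS_REMAPPING_DICT in model_loading_utils.py further down):

from diffusers import DiTTransformer2DModel, Transformer2DModel

# Transformer2DModel picks up the legacy mixins in this release, so a config
# declaring norm_type="ada_norm_zero" should come back as DiTTransformer2DModel.
config = {
    "norm_type": "ada_norm_zero",
    "num_embeds_ada_norm": 1000,
    "patch_size": 2,
    "sample_size": 8,
    "in_channels": 4,
    "num_layers": 2,
    "num_attention_heads": 2,
    "attention_head_dim": 8,
}
model = Transformer2DModel.from_config(config)
assert isinstance(model, DiTTransformer2DModel)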
diffusers/models/__init__.py
CHANGED
@@ -36,6 +36,9 @@ if is_torch_available():
     _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["embeddings"] = ["ImageProjection"]
     _import_structure["modeling_utils"] = ["ModelMixin"]
+    _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
+    _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
+    _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
     _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
@@ -73,7 +76,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .embeddings import ImageProjection
     from .modeling_utils import ModelMixin
     from .transformers import (
+        DiTTransformer2DModel,
         DualTransformer2DModel,
+        HunyuanDiT2DModel,
+        PixArtTransformer2DModel,
         PriorTransformer,
         T5FilmDecoder,
         Transformer2DModel,
diffusers/models/activations.py
CHANGED
@@ -50,6 +50,18 @@ def get_activation(act_fn: str) -> nn.Module:
         raise ValueError(f"Unsupported activation function: {act_fn}")
 
 
+class FP32SiLU(nn.Module):
+    r"""
+    SiLU activation function with input upcasted to torch.float32.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        return F.silu(inputs.float(), inplace=False).to(inputs.dtype)
+
+
 class GELU(nn.Module):
     r"""
     GELU activation function with tanh approximation support with `approximate="tanh"`.
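The point of FP32SiLU is numerical robustness: the activation always runs in float32 even when the surrounding module is half precision, and only the result is cast back. A small sketch of that contract, assuming nothing beyond the class above:

import torch

from diffusers.models.activations import FP32SiLU

act = FP32SiLU()
x = torch.randn(2, 8, dtype=torch.float16)
y = act(x)
assert y.dtype == torch.float16  # output dtype matches the input
# internally, per forward() above, the SiLU itself was computed on x.float()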
diffusers/models/attention_processor.py
CHANGED
@@ -103,6 +103,7 @@ class Attention(nn.Module):
         upcast_softmax: bool = False,
         cross_attention_norm: Optional[str] = None,
         cross_attention_norm_num_groups: int = 32,
+        qk_norm: Optional[str] = None,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         spatial_norm_dim: Optional[int] = None,
@@ -161,6 +162,15 @@
         else:
             self.spatial_norm = None
 
+        if qk_norm is None:
+            self.norm_q = None
+            self.norm_k = None
+        elif qk_norm == "layer_norm":
+            self.norm_q = nn.LayerNorm(dim_head, eps=eps)
+            self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+        else:
+            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
+
         if cross_attention_norm is None:
             self.norm_cross = None
         elif cross_attention_norm == "layer_norm":
@@ -1426,6 +1436,104 @@ class AttnProcessor2_0:
         return hidden_states
 
 
+class HunyuanAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        from .embeddings import apply_rotary_emb
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            if not attn.is_cross_attention:
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class FusedAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
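The two Attention hunks and the new processor compose: qk_norm="layer_norm" creates the norm_q/norm_k layers, and HunyuanAttnProcessor2_0 applies them (plus optional rotary embeddings) inside the PyTorch 2.0 scaled-dot-product path. A hedged self-attention sketch with arbitrary module sizes (requires PyTorch >= 2.0, per the ImportError above):

import torch

from diffusers.models.attention_processor import Attention, HunyuanAttnProcessor2_0

attn = Attention(
    query_dim=64,
    heads=4,
    dim_head=16,
    qk_norm="layer_norm",  # enables the new norm_q / norm_k LayerNorms
    eps=1e-6,
    processor=HunyuanAttnProcessor2_0(),
)
hidden_states = torch.randn(1, 32, 64)
out = attn(hidden_states)  # RoPE is skipped because image_rotary_emb is None
print(out.shape)  # torch.Size([1, 32, 64])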
diffusers/models/embeddings.py
CHANGED
@@ -16,10 +16,11 @@ from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch import nn
 
 from ..utils import deprecate
-from .activations import get_activation
+from .activations import FP32SiLU, get_activation
 from .attention_processor import Attention
 
 
@@ -135,6 +136,7 @@ class PatchEmbed(nn.Module):
         flatten=True,
         bias=True,
         interpolation_scale=1,
+        pos_embed_type="sincos",
     ):
         super().__init__()
 
@@ -156,10 +158,18 @@
         self.height, self.width = height // patch_size, width // patch_size
         self.base_size = height // patch_size
         self.interpolation_scale = interpolation_scale
-        pos_embed = get_2d_sincos_pos_embed(
-            embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
-        )
-        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+        if pos_embed_type is None:
+            self.pos_embed = None
+        elif pos_embed_type == "sincos":
+            pos_embed = get_2d_sincos_pos_embed(
+                embed_dim,
+                int(num_patches**0.5),
+                base_size=self.base_size,
+                interpolation_scale=self.interpolation_scale,
+            )
+            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+        else:
+            raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
 
     def forward(self, latent):
         height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
@@ -169,6 +179,8 @@ class PatchEmbed(nn.Module):
         latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
         if self.layer_norm:
             latent = self.norm(latent)
+        if self.pos_embed is None:
+            return latent.to(latent.dtype)
 
         # Interpolate positional embeddings if needed.
         # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
@@ -187,6 +199,113 @@ class PatchEmbed(nn.Module):
         return (latent + pos_embed).to(latent.dtype)
 
 
+def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
+    """
+    RoPE for image tokens with 2d structure.
+
+    Args:
+        embed_dim: (`int`):
+            The embedding dimension size
+        crops_coords (`Tuple[int]`)
+            The top-left and bottom-right coordinates of the crop.
+        grid_size (`Tuple[int]`):
+            The grid size of the positional embedding.
+        use_real (`bool`):
+            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+    Returns:
+        `torch.Tensor`: positional embdding with shape `( grid_size * grid_size, embed_dim/2)`.
+    """
+    start, stop = crops_coords
+    grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+    grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)  # [2, W, H]
+
+    grid = grid.reshape([2, 1, *grid.shape[1:]])
+    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+    return pos_embed
+
+
+def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+    assert embed_dim % 4 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
+    emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)
+
+    if use_real:
+        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D/2)
+        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D/2)
+        return cos, sin
+    else:
+        emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
+        return emb
+
+
+def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
+    """
+    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
+    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
+    data type.
+
+    Args:
+        dim (`int`): Dimension of the frequency tensor.
+        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+        theta (`float`, *optional*, defaults to 10000.0):
+            Scaling factor for frequency computation. Defaults to 10000.0.
+        use_real (`bool`, *optional*):
+            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+    Returns:
+        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+    """
+    if isinstance(pos, int):
+        pos = np.arange(pos)
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
+    t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
+    freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
+    if use_real:
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+        return freqs_cos, freqs_sin
+    else:
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
+        return freqs_cis
+
+
+def apply_rotary_emb(
+    x: torch.Tensor,
+    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+    tensors contain rotary embeddings and are returned as real tensors.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
+        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+    """
+    cos, sin = freqs_cis  # [S, D]
+    cos = cos[None, None]
+    sin = sin[None, None]
+    cos, sin = cos.to(x.device), sin.to(x.device)
+
+    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+    return out
+
+
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
@@ -507,6 +626,88 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
         return conditioning
 
 
+class HunyuanDiTAttentionPool(nn.Module):
+    # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
+
+    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+
+    def forward(self, x):
+        x = x.permute(1, 0, 2)  # NLC -> LNC
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
+        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
+        x, _ = F.multi_head_attention_forward(
+            query=x[:1],
+            key=x,
+            value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False,
+        )
+        return x.squeeze(0)
+
+
+class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
+    def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+        self.pooler = HunyuanDiTAttentionPool(
+            seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
+        )
+        # Here we use a default learned embedder layer for future extension.
+        self.style_embedder = nn.Embedding(1, embedding_dim)
+        extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        self.extra_embedder = PixArtAlphaTextProjection(
+            in_features=extra_in_dim,
+            hidden_size=embedding_dim * 4,
+            out_features=embedding_dim,
+            act_fn="silu_fp32",
+        )
+
+    def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, 256)
+
+        # extra condition1: text
+        pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)
+
+        # extra condition2: image meta size embdding
+        image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
+        image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+        image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
+
+        # extra condition3: style embedding
+        style_embedding = self.style_embedder(style)  # (N, embedding_dim)
+
+        # Concatenate all extra vectors
+        extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+        conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]
+
+        return conditioning
+
+
 class TextTimeEmbedding(nn.Module):
     def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
         super().__init__()
@@ -793,11 +994,18 @@ class PixArtAlphaTextProjection(nn.Module):
     Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
     """
 
-    def __init__(self, in_features, hidden_size,
+    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
         super().__init__()
+        if out_features is None:
+            out_features = hidden_size
         self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
-        self.act_1 = nn.GELU(approximate="tanh")
-        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
+        if act_fn == "gelu_tanh":
+            self.act_1 = nn.GELU(approximate="tanh")
+        elif act_fn == "silu_fp32":
+            self.act_1 = FP32SiLU()
+        else:
+            raise ValueError(f"Unknown activation function: {act_fn}")
+        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
 
     def forward(self, caption):
         hidden_states = self.linear_1(caption)
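Of the new embedding helpers, the RoPE trio is the one exercised by HunyuanAttnProcessor2_0: get_2d_rotary_pos_embed builds per-position cos/sin tables for a patch grid, and apply_rotary_emb rotates a [B, heads, seq, head_dim] query or key with them. A hedged sketch with toy sizes:

import torch

from diffusers.models.embeddings import apply_rotary_emb, get_2d_rotary_pos_embed

head_dim, grid = 32, (8, 8)  # 8x8 patch grid, rotary dim = head_dim
cos_sin = get_2d_rotary_pos_embed(
    head_dim, crops_coords=((0, 0), grid), grid_size=grid, use_real=True
)  # tuple of [64, 32] cos and sin tables
query = torch.randn(1, 4, 64, head_dim)  # [B, heads, seq, head_dim]
rotated = apply_rotary_emb(query, cos_sin)
print(rotated.shape)  # torch.Size([1, 4, 64, 32])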
diffusers/models/model_loading_utils.py
CHANGED
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib
 import inspect
 import os
 from collections import OrderedDict
@@ -32,6 +33,13 @@ from ..utils import (
 
 logger = logging.get_logger(__name__)
 
+_CLASS_REMAPPING_DICT = {
+    "Transformer2DModel": {
+        "ada_norm_zero": "DiTTransformer2DModel",
+        "ada_norm_single": "PixArtTransformer2DModel",
+    }
+}
+
 
 if is_accelerate_available():
     from accelerate import infer_auto_device_map
@@ -61,6 +69,26 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_
     return device_map
 
 
+def _fetch_remapped_cls_from_config(config, old_class):
+    previous_class_name = old_class.__name__
+    remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"], None)
+
+    # Details:
+    # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818
+    if remapped_class_name:
+        # load diffusers library to import compatible and original scheduler
+        diffusers_library = importlib.import_module(__name__.split(".")[0])
+        remapped_class = getattr(diffusers_library, remapped_class_name)
+        logger.info(
+            f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type."
+            f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
+            " DOESN'T affect the final results."
+        )
+        return remapped_class
+    else:
+        return old_class
+
+
 def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
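The remapping is driven entirely by the stored config: only Transformer2DModel has entries, keyed on norm_type, and anything unrecognized falls back to the legacy class. A standalone sketch of the same decision logic (written slightly more defensively than the original, which assumes the class is present in the dict):

remapping = {
    "Transformer2DModel": {
        "ada_norm_zero": "DiTTransformer2DModel",
        "ada_norm_single": "PixArtTransformer2DModel",
    }
}

def resolve(class_name: str, config: dict) -> str:
    # mirrors _fetch_remapped_cls_from_config: fall back to the old class
    # when norm_type has no dedicated replacement (e.g. "layer_norm")
    return remapping.get(class_name, {}).get(config.get("norm_type"), class_name)

assert resolve("Transformer2DModel", {"norm_type": "ada_norm_zero"}) == "DiTTransformer2DModel"
assert resolve("Transformer2DModel", {"norm_type": "ada_norm_single"}) == "PixArtTransformer2DModel"
assert resolve("Transformer2DModel", {"norm_type": "layer_norm"}) == "Transformer2DModel"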
diffusers/models/modeling_outputs.py
CHANGED
@@ -15,3 +15,17 @@ class AutoencoderKLOutput(BaseOutput):
     """
 
     latent_dist: "DiagonalGaussianDistribution"  # noqa: F821
+
+
+@dataclass
+class Transformer2DModelOutput(BaseOutput):
+    """
+    The output of [`Transformer2DModel`].
+
+    Args:
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+            distributions for the unnoised latent pixels.
+    """
+
+    sample: "torch.Tensor"  # noqa: F821
diffusers/models/modeling_utils.py
CHANGED
@@ -42,7 +42,11 @@ from ..utils import (
     is_torch_version,
     logging,
 )
-from ..utils.hub_utils import
+from ..utils.hub_utils import (
+    PushToHubMixin,
+    load_or_create_model_card,
+    populate_model_card,
+)
 from .model_loading_utils import (
     _determine_device_map,
     _load_state_dict_into_model,
@@ -1039,3 +1043,55 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             del module.key
             del module.value
             del module.proj_attn
+
+
+class LegacyModelMixin(ModelMixin):
+    r"""
+    A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
+    pipeline-specific classes (like `DiTTransformer2DModel`).
+    """
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+        # To prevent depedency import problem.
+        from .model_loading_utils import _fetch_remapped_cls_from_config
+
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", None)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+
+        # Load config if we don't provide a configuration
+        config_path = pretrained_model_name_or_path
+
+        user_agent = {
+            "diffusers": __version__,
+            "file_type": "model",
+            "framework": "pytorch",
+        }
+
+        # load config
+        config, _, _ = cls.load_config(
+            config_path,
+            cache_dir=cache_dir,
+            return_unused_kwargs=True,
+            return_commit_hash=True,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            **kwargs,
+        )
+        # resolve remapping
+        remapped_class = _fetch_remapped_cls_from_config(config, cls)
+
+        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
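In user terms, LegacyModelMixin means loading an old-style checkpoint through Transformer2DModel silently yields the specialized class. A hedged example (the repo id is illustrative, not taken from this diff; any Transformer2DModel checkpoint whose config says norm_type="ada_norm_zero" behaves the same):

from diffusers import DiTTransformer2DModel, Transformer2DModel

model = Transformer2DModel.from_pretrained("facebook/DiT-XL-2-256", subfolder="transformer")
# from_pretrained reads the config, finds norm_type == "ada_norm_zero",
# and hands loading off to the remapped class via _fetch_remapped_cls_from_config.
assert isinstance(model, DiTTransformer2DModel)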
diffusers/models/normalization.py
CHANGED
@@ -176,7 +176,8 @@ class AdaLayerNormContinuous(nn.Module):
             raise ValueError(f"unknown norm_type {norm_type}")
 
     def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
-        emb = self.linear(self.silu(conditioning_embedding))
+        # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
+        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
         scale, shift = torch.chunk(emb, 2, dim=1)
         x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
         return x
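The added cast matters when the conditioning path is deliberately kept in float32 (as HunyuanDiT does) while the hidden states run in half precision; without it, self.linear would see a float32 input against float16 weights. A minimal reproduction of the cast, assuming nothing beyond the hunk above:

import torch
import torch.nn.functional as F

x = torch.randn(2, 16, dtype=torch.float16)  # hidden states in fp16
conditioning = torch.randn(2, 16, dtype=torch.float32)  # upcasted embedding
emb = F.silu(conditioning).to(x.dtype)  # cast back before the linear projection
assert emb.dtype == x.dtype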
diffusers/models/transformers/__init__.py
CHANGED
@@ -2,7 +2,10 @@ from ...utils import is_torch_available
 
 
 if is_torch_available():
+    from .dit_transformer_2d import DiTTransformer2DModel
     from .dual_transformer_2d import DualTransformer2DModel
+    from .hunyuan_transformer_2d import HunyuanDiT2DModel
+    from .pixart_transformer_2d import PixArtTransformer2DModel
    from .prior_transformer import PriorTransformer
    from .t5_film_transformer import T5FilmDecoder
    from .transformer_2d import Transformer2DModel
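The subpackage now exposes the three new transformers alongside the existing ones, and the headline user-facing addition of this release is the HunyuanDiT pipeline listed at the top of the diff. A hedged usage sketch (the checkpoint id is an assumption, not taken from this diff; requires a CUDA device):

import torch

from diffusers import HunyuanDiTPipeline

pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers",  # assumed repo id; substitute any HunyuanDiT checkpoint
    torch_dtype=torch.float16,
).to("cuda")
image = pipe(prompt="a watercolor fox in a bamboo forest").images[0]
image.save("hunyuan_dit.png")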
|