diffusers 0.28.0__py3-none-any.whl → 0.28.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (30)
  1. diffusers/__init__.py +9 -1
  2. diffusers/configuration_utils.py +17 -0
  3. diffusers/models/__init__.py +6 -0
  4. diffusers/models/activations.py +12 -0
  5. diffusers/models/attention_processor.py +108 -0
  6. diffusers/models/embeddings.py +216 -8
  7. diffusers/models/model_loading_utils.py +28 -0
  8. diffusers/models/modeling_outputs.py +14 -0
  9. diffusers/models/modeling_utils.py +57 -1
  10. diffusers/models/normalization.py +2 -1
  11. diffusers/models/transformers/__init__.py +3 -0
  12. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  13. diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
  14. diffusers/models/transformers/pixart_transformer_2d.py +336 -0
  15. diffusers/models/transformers/transformer_2d.py +37 -45
  16. diffusers/pipelines/__init__.py +2 -0
  17. diffusers/pipelines/dit/pipeline_dit.py +4 -4
  18. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  19. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
  20. diffusers/pipelines/pipeline_loading_utils.py +1 -0
  21. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +4 -4
  22. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +2 -2
  23. diffusers/utils/dummy_pt_objects.py +45 -0
  24. diffusers/utils/dummy_torch_and_transformers_objects.py +15 -0
  25. {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/METADATA +44 -44
  26. {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/RECORD +30 -25
  27. {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/WHEEL +1 -1
  28. {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
  29. {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
  30. {diffusers-0.28.0.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
diffusers/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.28.0"
+ __version__ = "0.28.1"

  from typing import TYPE_CHECKING

@@ -82,11 +82,14 @@ else:
          "ConsistencyDecoderVAE",
          "ControlNetModel",
          "ControlNetXSAdapter",
+         "DiTTransformer2DModel",
+         "HunyuanDiT2DModel",
          "I2VGenXLUNet",
          "Kandinsky3UNet",
          "ModelMixin",
          "MotionAdapter",
          "MultiAdapter",
+         "PixArtTransformer2DModel",
          "PriorTransformer",
          "StableCascadeUNet",
          "T2IAdapter",
@@ -227,6 +230,7 @@ else:
          "BlipDiffusionPipeline",
          "CLIPImageProjection",
          "CycleDiffusionPipeline",
+         "HunyuanDiTPipeline",
          "I2VGenXLPipeline",
          "IFImg2ImgPipeline",
          "IFImg2ImgSuperResolutionPipeline",
@@ -484,11 +488,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
          ConsistencyDecoderVAE,
          ControlNetModel,
          ControlNetXSAdapter,
+         DiTTransformer2DModel,
+         HunyuanDiT2DModel,
          I2VGenXLUNet,
          Kandinsky3UNet,
          ModelMixin,
          MotionAdapter,
          MultiAdapter,
+         PixArtTransformer2DModel,
          PriorTransformer,
          T2IAdapter,
          T5FilmDecoder,
@@ -607,6 +614,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
          AudioLDMPipeline,
          CLIPImageProjection,
          CycleDiffusionPipeline,
+         HunyuanDiTPipeline,
          I2VGenXLPipeline,
          IFImg2ImgPipeline,
          IFImg2ImgSuperResolutionPipeline,
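
Taken together, the `__init__.py` changes expose the three new transformer classes and the HunyuanDiT pipeline at the top level of the package. A minimal smoke test of one of them, as a sketch (the tiny `num_layers` and the tensor shapes are illustrative, not taken from this diff):

    import torch
    from diffusers import DiTTransformer2DModel

    # Build a shrunken DiT-style transformer (default config otherwise follows DiT-XL/2)
    # and run one forward pass on random class-conditioned latents.
    model = DiTTransformer2DModel(num_layers=2)
    latents = torch.randn(1, 4, 32, 32)
    out = model(latents, timestep=torch.tensor([10]), class_labels=torch.tensor([0])).sample
    print(out.shape)  # same spatial size as the input latents
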
diffusers/configuration_utils.py CHANGED
@@ -706,3 +706,20 @@ def flax_register_to_config(cls):

      cls.__init__ = init
      return cls
+
+
+ class LegacyConfigMixin(ConfigMixin):
+     r"""
+     A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
+     pipeline-specific classes (like `DiTTransformer2DModel`).
+     """
+
+     @classmethod
+     def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
+         # To prevent depedency import problem.
+         from .models.model_loading_utils import _fetch_remapped_cls_from_config
+
+         # resolve remapping
+         remapped_class = _fetch_remapped_cls_from_config(config, cls)
+
+         return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
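
For orientation: this mixin is what lets a config saved for the catch-all `Transformer2DModel` resolve to one of the new pipeline-specific classes at `from_config` time. A minimal sketch, assuming `Transformer2DModel` now inherits from `LegacyConfigMixin` (its own diff is listed above but not shown here), with a deliberately tiny, hypothetical config:

    from diffusers import DiTTransformer2DModel, Transformer2DModel

    config = {"norm_type": "ada_norm_zero", "num_layers": 2}  # real configs carry many more keys
    model = Transformer2DModel.from_config(config)
    print(type(model) is DiTTransformer2DModel)  # expected: True
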
diffusers/models/__init__.py CHANGED
@@ -36,6 +36,9 @@ if is_torch_available():
      _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
      _import_structure["embeddings"] = ["ImageProjection"]
      _import_structure["modeling_utils"] = ["ModelMixin"]
+     _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
+     _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
+     _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
      _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
      _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
      _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
@@ -73,7 +76,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
      from .embeddings import ImageProjection
      from .modeling_utils import ModelMixin
      from .transformers import (
+         DiTTransformer2DModel,
          DualTransformer2DModel,
+         HunyuanDiT2DModel,
+         PixArtTransformer2DModel,
          PriorTransformer,
          T5FilmDecoder,
          Transformer2DModel,
diffusers/models/activations.py CHANGED
@@ -50,6 +50,18 @@ def get_activation(act_fn: str) -> nn.Module:
          raise ValueError(f"Unsupported activation function: {act_fn}")


+ class FP32SiLU(nn.Module):
+     r"""
+     SiLU activation function with input upcasted to torch.float32.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         return F.silu(inputs.float(), inplace=False).to(inputs.dtype)
+
+
  class GELU(nn.Module):
      r"""
      GELU activation function with tanh approximation support with `approximate="tanh"`.
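
A quick illustration of what the upcast buys (a sketch, not part of the diff): the SiLU itself is evaluated in float32 even when the surrounding module runs in half precision, while the output dtype still follows the input.

    import torch
    from diffusers.models.activations import FP32SiLU

    x = torch.randn(4, 8, dtype=torch.float16)
    y = FP32SiLU()(x)
    print(y.dtype)  # torch.float16 -- computed in float32 internally, cast back on return
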
diffusers/models/attention_processor.py CHANGED
@@ -103,6 +103,7 @@ class Attention(nn.Module):
          upcast_softmax: bool = False,
          cross_attention_norm: Optional[str] = None,
          cross_attention_norm_num_groups: int = 32,
+         qk_norm: Optional[str] = None,
          added_kv_proj_dim: Optional[int] = None,
          norm_num_groups: Optional[int] = None,
          spatial_norm_dim: Optional[int] = None,
@@ -161,6 +162,15 @@
          else:
              self.spatial_norm = None

+         if qk_norm is None:
+             self.norm_q = None
+             self.norm_k = None
+         elif qk_norm == "layer_norm":
+             self.norm_q = nn.LayerNorm(dim_head, eps=eps)
+             self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+         else:
+             raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
+
          if cross_attention_norm is None:
              self.norm_cross = None
          elif cross_attention_norm == "layer_norm":
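
The new argument only builds the per-head LayerNorms here; whether they are applied depends on the attention processor (the Hunyuan processor added further down uses them). A small sketch with illustrative dimensions:

    from diffusers.models.attention_processor import Attention

    attn = Attention(query_dim=128, heads=8, dim_head=16, qk_norm="layer_norm")
    print(type(attn.norm_q).__name__, type(attn.norm_k).__name__)  # LayerNorm LayerNorm
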
@@ -1426,6 +1436,104 @@ class AttnProcessor2_0:
          return hidden_states


+ class HunyuanAttnProcessor2_0:
+     r"""
+     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+     used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
+     """
+
+     def __init__(self):
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+     def __call__(
+         self,
+         attn: Attention,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         temb: Optional[torch.Tensor] = None,
+         image_rotary_emb: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         from .embeddings import apply_rotary_emb
+
+         residual = hidden_states
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+
+         if attention_mask is not None:
+             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+             # scaled_dot_product_attention expects attention_mask shape to be
+             # (batch, heads, source_length, target_length)
+             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         if attn.norm_q is not None:
+             query = attn.norm_q(query)
+         if attn.norm_k is not None:
+             key = attn.norm_k(key)
+
+         # Apply RoPE if needed
+         if image_rotary_emb is not None:
+             query = apply_rotary_emb(query, image_rotary_emb)
+             if not attn.is_cross_attention:
+                 key = apply_rotary_emb(key, image_rotary_emb)
+
+         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+         # TODO: add support for attn.scale when we move to Torch 2.1
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
+
+
  class FusedAttnProcessor2_0:
      r"""
      Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
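
Wiring the new processor onto an `Attention` block with `qk_norm` enabled then gives QK-normalized scaled dot-product attention; the rotary embedding is optional and simply skipped when `image_rotary_emb` is not passed. A sketch with illustrative shapes:

    import torch
    from diffusers.models.attention_processor import Attention, HunyuanAttnProcessor2_0

    attn = Attention(
        query_dim=128, heads=8, dim_head=16, qk_norm="layer_norm", processor=HunyuanAttnProcessor2_0()
    )
    out = attn(torch.randn(1, 64, 128))  # (batch, tokens, channels); no image_rotary_emb passed
    print(out.shape)                     # torch.Size([1, 64, 128])
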
diffusers/models/embeddings.py CHANGED
@@ -16,10 +16,11 @@ from typing import List, Optional, Tuple, Union

  import numpy as np
  import torch
+ import torch.nn.functional as F
  from torch import nn

  from ..utils import deprecate
- from .activations import get_activation
+ from .activations import FP32SiLU, get_activation
  from .attention_processor import Attention


@@ -135,6 +136,7 @@ class PatchEmbed(nn.Module):
          flatten=True,
          bias=True,
          interpolation_scale=1,
+         pos_embed_type="sincos",
      ):
          super().__init__()

@@ -156,10 +158,18 @@
          self.height, self.width = height // patch_size, width // patch_size
          self.base_size = height // patch_size
          self.interpolation_scale = interpolation_scale
-         pos_embed = get_2d_sincos_pos_embed(
-             embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
-         )
-         self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+         if pos_embed_type is None:
+             self.pos_embed = None
+         elif pos_embed_type == "sincos":
+             pos_embed = get_2d_sincos_pos_embed(
+                 embed_dim,
+                 int(num_patches**0.5),
+                 base_size=self.base_size,
+                 interpolation_scale=self.interpolation_scale,
+             )
+             self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+         else:
+             raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")

      def forward(self, latent):
          height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
@@ -169,6 +179,8 @@
          latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
          if self.layer_norm:
              latent = self.norm(latent)
+         if self.pos_embed is None:
+             return latent.to(latent.dtype)

          # Interpolate positional embeddings if needed.
          # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
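
With `pos_embed_type=None` the module returns bare patch tokens and leaves positional information to the caller (HunyuanDiT injects rotary embeddings inside attention instead). A sketch with illustrative sizes:

    import torch
    from diffusers.models.embeddings import PatchEmbed

    patchify = PatchEmbed(height=32, width=32, patch_size=2, in_channels=4, embed_dim=128, pos_embed_type=None)
    tokens = patchify(torch.randn(1, 4, 32, 32))
    print(tokens.shape)  # torch.Size([1, 256, 128]) -- (32/2) * (32/2) tokens, no sincos table added
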
@@ -187,6 +199,113 @@ class PatchEmbed(nn.Module):
          return (latent + pos_embed).to(latent.dtype)


+ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
+     """
+     RoPE for image tokens with 2d structure.
+
+     Args:
+         embed_dim: (`int`):
+             The embedding dimension size
+         crops_coords (`Tuple[int]`)
+             The top-left and bottom-right coordinates of the crop.
+         grid_size (`Tuple[int]`):
+             The grid size of the positional embedding.
+         use_real (`bool`):
+             If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+     Returns:
+         `torch.Tensor`: positional embdding with shape `( grid_size * grid_size, embed_dim/2)`.
+     """
+     start, stop = crops_coords
+     grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+     grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)  # [2, W, H]
+
+     grid = grid.reshape([2, 1, *grid.shape[1:]])
+     pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+     return pos_embed
+
+
+ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+     assert embed_dim % 4 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
+     emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)
+
+     if use_real:
+         cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D/2)
+         sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D/2)
+         return cos, sin
+     else:
+         emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
+         return emb
+
+
+ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
+     """
+     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+     This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
+     index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
+     data type.
+
+     Args:
+         dim (`int`): Dimension of the frequency tensor.
+         pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+         theta (`float`, *optional*, defaults to 10000.0):
+             Scaling factor for frequency computation. Defaults to 10000.0.
+         use_real (`bool`, *optional*):
+             If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+     Returns:
+         `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+     """
+     if isinstance(pos, int):
+         pos = np.arange(pos)
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
+     t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
+     freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
+     if use_real:
+         freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
+         freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+         return freqs_cos, freqs_sin
+     else:
+         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
+         return freqs_cis
+
+
+ def apply_rotary_emb(
+     x: torch.Tensor,
+     freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+     to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+     reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+     tensors contain rotary embeddings and are returned as real tensors.
+
+     Args:
+         x (`torch.Tensor`):
+             Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
+         freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+     """
+     cos, sin = freqs_cis  # [S, D]
+     cos = cos[None, None]
+     sin = sin[None, None]
+     cos, sin = cos.to(x.device), sin.to(x.device)
+
+     x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+     x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+     out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+     return out
+
+
  class TimestepEmbedding(nn.Module):
      def __init__(
          self,
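
The three helpers compose as follows: build a (cos, sin) pair for every position of the latent grid, then rotate the per-head query/key channels with it. A worked sketch with illustrative sizes (the head dimension must be divisible by 4):

    import torch
    from diffusers.models.embeddings import apply_rotary_emb, get_2d_rotary_pos_embed

    head_dim, grid = 64, (8, 8)
    image_rotary_emb = get_2d_rotary_pos_embed(head_dim, crops_coords=((0, 0), grid), grid_size=grid, use_real=True)

    query = torch.randn(1, 16, grid[0] * grid[1], head_dim)  # (batch, heads, tokens, head_dim)
    query = apply_rotary_emb(query, image_rotary_emb)
    print(query.shape)  # torch.Size([1, 16, 64, 64]) -- shape is unchanged, only the phases are rotated
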
@@ -507,6 +626,88 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
          return conditioning


+ class HunyuanDiTAttentionPool(nn.Module):
+     # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
+
+     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+         super().__init__()
+         self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
+         self.k_proj = nn.Linear(embed_dim, embed_dim)
+         self.q_proj = nn.Linear(embed_dim, embed_dim)
+         self.v_proj = nn.Linear(embed_dim, embed_dim)
+         self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+         self.num_heads = num_heads
+
+     def forward(self, x):
+         x = x.permute(1, 0, 2)  # NLC -> LNC
+         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
+         x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
+         x, _ = F.multi_head_attention_forward(
+             query=x[:1],
+             key=x,
+             value=x,
+             embed_dim_to_check=x.shape[-1],
+             num_heads=self.num_heads,
+             q_proj_weight=self.q_proj.weight,
+             k_proj_weight=self.k_proj.weight,
+             v_proj_weight=self.v_proj.weight,
+             in_proj_weight=None,
+             in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+             bias_k=None,
+             bias_v=None,
+             add_zero_attn=False,
+             dropout_p=0,
+             out_proj_weight=self.c_proj.weight,
+             out_proj_bias=self.c_proj.bias,
+             use_separate_proj_weight=True,
+             training=self.training,
+             need_weights=False,
+         )
+         return x.squeeze(0)
+
+
+ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
+     def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
+         super().__init__()
+
+         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+         self.pooler = HunyuanDiTAttentionPool(
+             seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
+         )
+         # Here we use a default learned embedder layer for future extension.
+         self.style_embedder = nn.Embedding(1, embedding_dim)
+         extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+         self.extra_embedder = PixArtAlphaTextProjection(
+             in_features=extra_in_dim,
+             hidden_size=embedding_dim * 4,
+             out_features=embedding_dim,
+             act_fn="silu_fp32",
+         )
+
+     def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
+         timesteps_proj = self.time_proj(timestep)
+         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, 256)
+
+         # extra condition1: text
+         pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)
+
+         # extra condition2: image meta size embdding
+         image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
+         image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+         image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
+
+         # extra condition3: style embedding
+         style_embedding = self.style_embedder(style)  # (N, embedding_dim)
+
+         # Concatenate all extra vectors
+         extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+         conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]
+
+         return conditioning
+
+
  class TextTimeEmbedding(nn.Module):
      def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
          super().__init__()
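
The combined embedding folds four signals into one conditioning vector: the timestep, an attention-pooled text embedding, six image-size/crop values, and a style id. A shape-check sketch with illustrative dimensions (1408 matches HunyuanDiT's inner width, but any value works):

    import torch
    from diffusers.models.embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding

    emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(embedding_dim=1408)
    conditioning = emb(
        timestep=torch.tensor([1]),
        encoder_hidden_states=torch.randn(1, 256, 2048),  # pooled down to 1024 dims by the attention pool
        image_meta_size=torch.zeros(1, 6),                # six size/crop values, each embedded to 256 dims
        style=torch.zeros(1, dtype=torch.long),           # the learned style table has a single entry
        hidden_dtype=torch.float32,
    )
    print(conditioning.shape)  # torch.Size([1, 1408])
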
@@ -793,11 +994,18 @@ class PixArtAlphaTextProjection(nn.Module):
      Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
      """

-     def __init__(self, in_features, hidden_size, num_tokens=120):
+     def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
          super().__init__()
+         if out_features is None:
+             out_features = hidden_size
          self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
-         self.act_1 = nn.GELU(approximate="tanh")
-         self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
+         if act_fn == "gelu_tanh":
+             self.act_1 = nn.GELU(approximate="tanh")
+         elif act_fn == "silu_fp32":
+             self.act_1 = FP32SiLU()
+         else:
+             raise ValueError(f"Unknown activation function: {act_fn}")
+         self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

      def forward(self, caption):
          hidden_states = self.linear_1(caption)
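
The extended projection is exactly what the Hunyuan embedding above instantiates: an in_features -> 4x -> out_features MLP with the float32 SiLU. A sketch reusing those illustrative sizes:

    import torch
    from diffusers.models.embeddings import PixArtAlphaTextProjection

    proj = PixArtAlphaTextProjection(in_features=3968, hidden_size=1408 * 4, out_features=1408, act_fn="silu_fp32")
    print(proj(torch.randn(1, 3968)).shape)  # torch.Size([1, 1408])
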
diffusers/models/model_loading_utils.py CHANGED
@@ -14,6 +14,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ import importlib
  import inspect
  import os
  from collections import OrderedDict
@@ -32,6 +33,13 @@ from ..utils import (

  logger = logging.get_logger(__name__)

+ _CLASS_REMAPPING_DICT = {
+     "Transformer2DModel": {
+         "ada_norm_zero": "DiTTransformer2DModel",
+         "ada_norm_single": "PixArtTransformer2DModel",
+     }
+ }
+

  if is_accelerate_available():
      from accelerate import infer_auto_device_map
@@ -61,6 +69,26 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_
      return device_map


+ def _fetch_remapped_cls_from_config(config, old_class):
+     previous_class_name = old_class.__name__
+     remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"], None)
+
+     # Details:
+     # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818
+     if remapped_class_name:
+         # load diffusers library to import compatible and original scheduler
+         diffusers_library = importlib.import_module(__name__.split(".")[0])
+         remapped_class = getattr(diffusers_library, remapped_class_name)
+         logger.info(
+             f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type."
+             f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
+             " DOESN'T affect the final results."
+         )
+         return remapped_class
+     else:
+         return old_class
+
+
  def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
      """
      Reads a checkpoint file, returning properly formatted errors if they arise.
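
The remapping keys off the legacy class name plus the config's `norm_type`, and falls back to the original class when no entry matches. A quick check (the helper is private, so this is purely illustrative):

    from diffusers import DiTTransformer2DModel, PixArtTransformer2DModel, Transformer2DModel
    from diffusers.models.model_loading_utils import _fetch_remapped_cls_from_config

    remap = _fetch_remapped_cls_from_config
    print(remap({"norm_type": "ada_norm_zero"}, Transformer2DModel) is DiTTransformer2DModel)        # True
    print(remap({"norm_type": "ada_norm_single"}, Transformer2DModel) is PixArtTransformer2DModel)   # True
    print(remap({"norm_type": "layer_norm"}, Transformer2DModel) is Transformer2DModel)              # True
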
diffusers/models/modeling_outputs.py CHANGED
@@ -15,3 +15,17 @@ class AutoencoderKLOutput(BaseOutput):
      """

      latent_dist: "DiagonalGaussianDistribution"  # noqa: F821
+
+
+ @dataclass
+ class Transformer2DModelOutput(BaseOutput):
+     """
+     The output of [`Transformer2DModel`].
+
+     Args:
+         sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+             The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+             distributions for the unnoised latent pixels.
+     """
+
+     sample: "torch.Tensor"  # noqa: F821
diffusers/models/modeling_utils.py CHANGED
@@ -42,7 +42,11 @@ from ..utils import (
      is_torch_version,
      logging,
  )
- from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
+ from ..utils.hub_utils import (
+     PushToHubMixin,
+     load_or_create_model_card,
+     populate_model_card,
+ )
  from .model_loading_utils import (
      _determine_device_map,
      _load_state_dict_into_model,
@@ -1039,3 +1043,55 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
              del module.key
              del module.value
              del module.proj_attn
+
+
+ class LegacyModelMixin(ModelMixin):
+     r"""
+     A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
+     pipeline-specific classes (like `DiTTransformer2DModel`).
+     """
+
+     @classmethod
+     @validate_hf_hub_args
+     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+         # To prevent depedency import problem.
+         from .model_loading_utils import _fetch_remapped_cls_from_config
+
+         cache_dir = kwargs.pop("cache_dir", None)
+         force_download = kwargs.pop("force_download", False)
+         resume_download = kwargs.pop("resume_download", None)
+         proxies = kwargs.pop("proxies", None)
+         local_files_only = kwargs.pop("local_files_only", None)
+         token = kwargs.pop("token", None)
+         revision = kwargs.pop("revision", None)
+         subfolder = kwargs.pop("subfolder", None)
+
+         # Load config if we don't provide a configuration
+         config_path = pretrained_model_name_or_path
+
+         user_agent = {
+             "diffusers": __version__,
+             "file_type": "model",
+             "framework": "pytorch",
+         }
+
+         # load config
+         config, _, _ = cls.load_config(
+             config_path,
+             cache_dir=cache_dir,
+             return_unused_kwargs=True,
+             return_commit_hash=True,
+             force_download=force_download,
+             resume_download=resume_download,
+             proxies=proxies,
+             local_files_only=local_files_only,
+             token=token,
+             revision=revision,
+             subfolder=subfolder,
+             user_agent=user_agent,
+             **kwargs,
+         )
+         # resolve remapping
+         remapped_class = _fetch_remapped_cls_from_config(config, cls)
+
+         return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
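
In practice this means `from_pretrained` on the legacy class silently hands back the new class. A sketch, assuming `Transformer2DModel` now inherits from `LegacyModelMixin` and using an illustrative PixArt checkpoint id (any repo whose transformer was saved as a legacy `Transformer2DModel` with `norm_type="ada_norm_single"` behaves the same way):

    from diffusers import Transformer2DModel

    transformer = Transformer2DModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer")
    print(type(transformer).__name__)  # PixArtTransformer2DModel
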
diffusers/models/normalization.py CHANGED
@@ -176,7 +176,8 @@ class AdaLayerNormContinuous(nn.Module):
          raise ValueError(f"unknown norm_type {norm_type}")

      def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
-         emb = self.linear(self.silu(conditioning_embedding))
+         # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
+         emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
          scale, shift = torch.chunk(emb, 2, dim=1)
          x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
          return x
diffusers/models/transformers/__init__.py CHANGED
@@ -2,7 +2,10 @@ from ...utils import is_torch_available


  if is_torch_available():
+     from .dit_transformer_2d import DiTTransformer2DModel
      from .dual_transformer_2d import DualTransformer2DModel
+     from .hunyuan_transformer_2d import HunyuanDiT2DModel
+     from .pixart_transformer_2d import PixArtTransformer2DModel
      from .prior_transformer import PriorTransformer
      from .t5_film_transformer import T5FilmDecoder
      from .transformer_2d import Transformer2DModel