diffsynth-engine 0.7.0__py3-none-any.whl → 0.7.1.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1065 @@
1
+ import math
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+
9
+ from diffsynth_engine.models.base import PreTrainedModel
10
+ from diffsynth_engine.models.basic.transformer_helper import RMSNorm
11
+ from diffsynth_engine.models.basic import attention as attention_ops
12
+ from diffsynth_engine.utils.gguf import gguf_inference
13
+ from diffsynth_engine.utils.fp8_linear import fp8_inference
14
+ from diffsynth_engine.utils.parallel import (
15
+ cfg_parallel,
16
+ cfg_parallel_unshard,
17
+ sequence_parallel,
18
+ sequence_parallel_unshard,
19
+ )
20
+
21
+
22
+ def get_timestep_embedding(
23
+ timesteps: torch.Tensor,
24
+ embedding_dim: int,
25
+ flip_sin_to_cos: bool = False,
26
+ downscale_freq_shift: float = 1,
27
+ scale: float = 1,
28
+ max_period: int = 10000,
29
+ ) -> torch.Tensor:
30
+ """
31
+ Create sinusoidal timestep embeddings. This matches the implementation in Denoising Diffusion Probabilistic Models.
32
+
33
+ Args:
34
+ timesteps (torch.Tensor):
35
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
36
+ embedding_dim (int):
37
+ the dimension of the output.
38
+ flip_sin_to_cos (bool):
39
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
40
+ downscale_freq_shift (float):
41
+ Controls the delta between frequencies across dimensions.
42
+ scale (float):
43
+ Scaling factor applied to the embeddings.
44
+ max_period (int):
45
+ Controls the maximum frequency of the embeddings
46
+ Returns:
47
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
48
+ """
49
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
50
+
51
+ half_dim = embedding_dim // 2
52
+ exponent = -math.log(max_period) * torch.arange(
53
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
54
+ )
55
+ exponent = exponent / (half_dim - downscale_freq_shift)
56
+
57
+ emb = torch.exp(exponent)
58
+ emb = timesteps[:, None].float() * emb[None, :]
59
+
60
+ # scale embeddings
61
+ emb = scale * emb
62
+
63
+ # concat sine and cosine embeddings
64
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
65
+
66
+ # flip sine and cosine embeddings
67
+ if flip_sin_to_cos:
68
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
69
+
70
+ # zero pad
71
+ if embedding_dim % 2 == 1:
72
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
73
+ return emb
74
+
75
+
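A minimal shape-check sketch of how this helper behaves (illustrative only, not part of the package source; it assumes the imports at the top of this module):

    # 4 timesteps embedded into 256 dimensions; with flip_sin_to_cos=True the
    # first half of the last axis holds cosines, the second half sines.
    t = torch.tensor([0.0, 250.0, 500.0, 999.0])
    emb = get_timestep_embedding(t, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0)
    assert emb.shape == (4, 256)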
76
+ class TimestepEmbedding(nn.Module):
77
+ def __init__(
78
+ self,
79
+ in_channels: int,
80
+ time_embed_dim: int,
81
+ act_fn: str = "silu",
82
+ out_dim: Optional[int] = None,
83
+ post_act_fn: Optional[str] = None,
84
+ cond_proj_dim=None,
85
+ sample_proj_bias=True,
86
+ ):
87
+ super().__init__()
88
+
89
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
90
+
91
+ if cond_proj_dim is not None:
92
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
93
+ else:
94
+ self.cond_proj = None
95
+
96
+ self.act = torch.nn.SiLU()  # act_fn is accepted for config compatibility; only "silu" is implemented here
97
+
98
+ if out_dim is not None:
99
+ time_embed_dim_out = out_dim
100
+ else:
101
+ time_embed_dim_out = time_embed_dim
102
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
103
+
104
+ # post_act_fn is accepted for compatibility; only None (no post-activation) is implemented here
105
+ self.post_act = None
106
+
107
+ def forward(self, sample, condition=None):
108
+ if condition is not None:
109
+ sample = sample + self.cond_proj(condition)
110
+ sample = self.linear_1(sample)
111
+
112
+ if self.act is not None:
113
+ sample = self.act(sample)
114
+
115
+ sample = self.linear_2(sample)
116
+
117
+ if self.post_act is not None:
118
+ sample = self.post_act(sample)
119
+ return sample
120
+
121
+
122
+ class Timesteps(nn.Module):
123
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
124
+ super().__init__()
125
+ self.num_channels = num_channels
126
+ self.flip_sin_to_cos = flip_sin_to_cos
127
+ self.downscale_freq_shift = downscale_freq_shift
128
+ self.scale = scale
129
+
130
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
131
+ t_emb = get_timestep_embedding(
132
+ timesteps,
133
+ self.num_channels,
134
+ flip_sin_to_cos=self.flip_sin_to_cos,
135
+ downscale_freq_shift=self.downscale_freq_shift,
136
+ scale=self.scale,
137
+ )
138
+ return t_emb
139
+
140
+
141
+ class AdaLayerNormContinuous(nn.Module):
142
+ r"""
143
+ Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
144
+
145
+ Args:
146
+ embedding_dim (`int`): Embedding dimension to use during projection.
147
+ conditioning_embedding_dim (`int`): Dimension of the input condition.
148
+ elementwise_affine (`bool`, defaults to `True`):
149
+ Boolean flag to denote if affine transformation should be applied.
150
+ eps (`float`, defaults to 1e-5): Epsilon factor.
151
+ bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
152
+ norm_type (`str`, defaults to `"layer_norm"`):
153
+ Normalization layer to use. Only "layer_norm" is supported by this implementation.
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ embedding_dim: int,
159
+ conditioning_embedding_dim: int,
160
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
161
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
162
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
163
+ # However, this is how it was implemented in the original code, and it's rather likely you should
164
+ # set `elementwise_affine` to False.
165
+ elementwise_affine=True,
166
+ eps=1e-5,
167
+ bias=True,
168
+ norm_type="layer_norm",
169
+ ):
170
+ super().__init__()
171
+ self.silu = nn.SiLU()
172
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
173
+ if norm_type == "layer_norm":
174
+ self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
175
+
176
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
177
+ # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for HunyuanDiT)
178
+ emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
179
+ scale, shift = torch.chunk(emb, 2, dim=1)
180
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
181
+ return x
182
+
183
+
184
+ def get_1d_rotary_pos_embed(
185
+ dim: int,
186
+ pos: Union[np.ndarray, int],
187
+ theta: float = 10000.0,
188
+ use_real=False,
189
+ linear_factor=1.0,
190
+ ntk_factor=1.0,
191
+ repeat_interleave_real=True,
192
+ freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
193
+ ):
194
+ """
195
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
196
+
197
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
198
+ position indices 'pos'. The 'theta' parameter scales the frequencies. When `use_real` is False, the returned
199
+ tensor contains complex values in complex64 data type.
200
+
201
+ Args:
202
+ dim (`int`): Dimension of the frequency tensor.
203
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
204
+ theta (`float`, *optional*, defaults to 10000.0):
205
+ Scaling factor for frequency computation. Defaults to 10000.0.
206
+ use_real (`bool`, *optional*):
207
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
208
+ linear_factor (`float`, *optional*, defaults to 1.0):
209
+ Scaling factor for the context extrapolation. Defaults to 1.0.
210
+ ntk_factor (`float`, *optional*, defaults to 1.0):
211
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
212
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
213
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
214
+ Otherwise, they are concatenated with themselves.
215
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
216
+ the dtype of the frequency tensor.
217
+ Returns:
218
+ `torch.Tensor` or `Tuple[torch.Tensor, torch.Tensor]`: the complex frequency tensor [S, D/2], or a (cos, sin) pair of shape [S, D] each when `use_real` is True.
219
+ """
220
+ assert dim % 2 == 0
221
+
222
+ if isinstance(pos, int):
223
+ pos = torch.arange(pos)
224
+ if isinstance(pos, np.ndarray):
225
+ pos = torch.from_numpy(pos) # type: ignore # [S]
226
+
227
+ theta = theta * ntk_factor
228
+ freqs = (
229
+ 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
230
+ ) # [D/2]
231
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
232
+ is_npu = freqs.device.type == "npu"
233
+ if is_npu:
234
+ freqs = freqs.float()
235
+ if use_real and repeat_interleave_real:
236
+ # flux, hunyuan-dit, cogvideox
237
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
238
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
239
+ return freqs_cos, freqs_sin
240
+ elif use_real:
241
+ # stable audio, allegro
242
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
243
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
244
+ return freqs_cos, freqs_sin
245
+ else:
246
+ # lumina
247
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
248
+ return freqs_cis
249
+
250
+
251
+ def apply_rotary_emb(
252
+ x: torch.Tensor,
253
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
254
+ use_real: bool = True,
255
+ use_real_unbind_dim: int = -1,
256
+ sequence_dim: int = 2,
257
+ ) -> torch.Tensor:
258
+ """
259
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
260
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
261
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
262
+ tensors contain rotary embeddings and are returned as real tensors.
263
+
264
+ Args:
265
+ x (`torch.Tensor`):
266
+ Query or key tensor to apply rotary embeddings, of shape [B, H, S, D] (or [B, S, H, D] when `sequence_dim=1`).
267
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
268
+
269
+ Returns:
270
+ torch.Tensor: The input tensor with rotary embeddings applied.
271
+ """
272
+ if use_real:
273
+ cos, sin = freqs_cis # [S, D]
274
+ if sequence_dim == 2:
275
+ cos = cos[None, None, :, :]
276
+ sin = sin[None, None, :, :]
277
+ elif sequence_dim == 1:
278
+ cos = cos[None, :, None, :]
279
+ sin = sin[None, :, None, :]
280
+ else:
281
+ raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
282
+
283
+ cos, sin = cos.to(x.device), sin.to(x.device)
284
+
285
+ if use_real_unbind_dim == -1:
286
+ # Used for flux, cogvideox, hunyuan-dit
287
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
288
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
289
+ elif use_real_unbind_dim == -2:
290
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
291
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
292
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
293
+ else:
294
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
295
+
296
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
297
+
298
+ return out
299
+ else:
300
+ # used for lumina
301
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
302
+ freqs_cis = freqs_cis.unsqueeze(2)
303
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
304
+
305
+ return x_out.type_as(x)
306
+
307
+
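The two rotary helpers above are used together by the attention processors below; a small sketch (illustrative, not part of the package) of the expected shapes with sequence_dim=1, which is the layout used in this file:

    # Per-position cos/sin tables for a 64-dim head over 16 positions.
    cos, sin = get_1d_rotary_pos_embed(dim=64, pos=16, use_real=True, repeat_interleave_real=True)  # each [16, 64]
    q = torch.randn(2, 16, 8, 64)  # [B, S, H, D] to match sequence_dim=1
    q_rot = apply_rotary_emb(q, (cos, sin), use_real=True, sequence_dim=1)
    assert q_rot.shape == q.shape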
308
+ def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
309
+ query = attn.to_q(hidden_states)
310
+ key = attn.to_k(hidden_states)
311
+ value = attn.to_v(hidden_states)
312
+
313
+ encoder_query = encoder_key = encoder_value = None
314
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
315
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
316
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
317
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
318
+
319
+ return query, key, value, encoder_query, encoder_key, encoder_value
320
+
321
+
322
+ def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
323
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
324
+
325
+ encoder_query = encoder_key = encoder_value = None
326
+ if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
327
+ encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
328
+
329
+ return query, key, value, encoder_query, encoder_key, encoder_value
330
+
331
+
332
+ def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
333
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
334
+
335
+
336
+ class Flux2SwiGLU(nn.Module):
337
+ """
338
+ Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the gate projection
339
+ fused into the first linear layer of the FF sub-block. Thus, this module itself has no trainable parameters.
340
+ """
341
+
342
+ def __init__(self):
343
+ super().__init__()
344
+ self.gate_fn = nn.SiLU()
345
+
346
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
347
+ x1, x2 = x.chunk(2, dim=-1)
348
+ x = self.gate_fn(x1) * x2
349
+ return x
350
+
351
+
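A short sketch of what this parameter-free activation does (illustrative only): the gate and value halves both come from the preceding fused linear projection, so the activation halves the last dimension and owns no weights.

    swiglu = Flux2SwiGLU()
    x = torch.randn(2, 10, 2 * 1024)   # output of a fused "linear_in"-style projection
    y = swiglu(x)                      # silu(x[..., :1024]) * x[..., 1024:]
    assert y.shape == (2, 10, 1024)
    assert sum(p.numel() for p in swiglu.parameters()) == 0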
352
+ class Flux2FeedForward(nn.Module):
353
+ def __init__(
354
+ self,
355
+ dim: int,
356
+ dim_out: Optional[int] = None,
357
+ mult: float = 3.0,
358
+ inner_dim: Optional[int] = None,
359
+ bias: bool = False,
360
+ ):
361
+ super().__init__()
362
+ if inner_dim is None:
363
+ inner_dim = int(dim * mult)
364
+ dim_out = dim_out or dim
365
+
366
+ # Flux2SwiGLU will reduce the dimension by half
367
+ self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias)
368
+ self.act_fn = Flux2SwiGLU()
369
+ self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias)
370
+
371
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
372
+ x = self.linear_in(x)
373
+ x = self.act_fn(x)
374
+ x = self.linear_out(x)
375
+ return x
376
+
377
+
378
+ class Flux2AttnProcessor:
379
+ _attention_backend = None
380
+ _parallel_config = None
381
+
382
+ def __init__(self):
383
+ if not hasattr(F, "scaled_dot_product_attention"):
384
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0 or later. Please upgrade your PyTorch version.")
385
+
386
+ def __call__(
387
+ self,
388
+ attn: "Flux2Attention",
389
+ hidden_states: torch.Tensor,
390
+ encoder_hidden_states: torch.Tensor = None,
391
+ attention_mask: Optional[torch.Tensor] = None,
392
+ image_rotary_emb: Optional[torch.Tensor] = None,
393
+ ) -> torch.Tensor:
394
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
395
+ attn, hidden_states, encoder_hidden_states
396
+ )
397
+
398
+ query = query.unflatten(-1, (attn.heads, -1))
399
+ key = key.unflatten(-1, (attn.heads, -1))
400
+ value = value.unflatten(-1, (attn.heads, -1))
401
+
402
+ query = attn.norm_q(query)
403
+ key = attn.norm_k(key)
404
+
405
+ if attn.added_kv_proj_dim is not None:
406
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
407
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
408
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
409
+
410
+ encoder_query = attn.norm_added_q(encoder_query)
411
+ encoder_key = attn.norm_added_k(encoder_key)
412
+
413
+ query = torch.cat([encoder_query, query], dim=1)
414
+ key = torch.cat([encoder_key, key], dim=1)
415
+ value = torch.cat([encoder_value, value], dim=1)
416
+
417
+ if image_rotary_emb is not None:
418
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
419
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
420
+
421
+ hidden_states = attention_ops.attention(
422
+ query,
423
+ key,
424
+ value,
425
+ q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
426
+ )
427
+ hidden_states = hidden_states.flatten(2, 3)
428
+ hidden_states = hidden_states.to(query.dtype)
429
+
430
+ if encoder_hidden_states is not None:
431
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
432
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
433
+ )
434
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
435
+
436
+ hidden_states = attn.to_out[0](hidden_states)
437
+ hidden_states = attn.to_out[1](hidden_states)
438
+
439
+ if encoder_hidden_states is not None:
440
+ return hidden_states, encoder_hidden_states
441
+ else:
442
+ return hidden_states
443
+
444
+
445
+ class Flux2Attention(torch.nn.Module):
446
+ _default_processor_cls = Flux2AttnProcessor
447
+ _available_processors = [Flux2AttnProcessor]
448
+
449
+ def __init__(
450
+ self,
451
+ query_dim: int,
452
+ heads: int = 8,
453
+ dim_head: int = 64,
454
+ dropout: float = 0.0,
455
+ bias: bool = False,
456
+ added_kv_proj_dim: Optional[int] = None,
457
+ added_proj_bias: Optional[bool] = True,
458
+ out_bias: bool = True,
459
+ eps: float = 1e-5,
460
+ out_dim: Optional[int] = None,
461
+ elementwise_affine: bool = True,
462
+ processor=None,
463
+ ):
464
+ super().__init__()
465
+
466
+ self.head_dim = dim_head
467
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
468
+ self.query_dim = query_dim
469
+ self.out_dim = out_dim if out_dim is not None else query_dim
470
+ self.heads = out_dim // dim_head if out_dim is not None else heads
471
+
472
+ self.use_bias = bias
473
+ self.dropout = dropout
474
+
475
+ self.added_kv_proj_dim = added_kv_proj_dim
476
+ self.added_proj_bias = added_proj_bias
477
+
478
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
479
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
480
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
481
+
482
+ # QK Norm
483
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
484
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
485
+
486
+ self.to_out = torch.nn.ModuleList([])
487
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
488
+ self.to_out.append(torch.nn.Dropout(dropout))
489
+
490
+ if added_kv_proj_dim is not None:
491
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
492
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
493
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
494
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
495
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
496
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
497
+
498
+ if processor is None:
499
+ processor = self._default_processor_cls()
500
+ self.processor = processor
501
+
502
+ def forward(
503
+ self,
504
+ hidden_states: torch.Tensor,
505
+ encoder_hidden_states: Optional[torch.Tensor] = None,
506
+ attention_mask: Optional[torch.Tensor] = None,
507
+ image_rotary_emb: Optional[torch.Tensor] = None,
508
+ **kwargs,
509
+ ) -> torch.Tensor:
510
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
511
+
512
+
513
+ class Flux2ParallelSelfAttnProcessor:
514
+ _attention_backend = None
515
+ _parallel_config = None
516
+
517
+ def __init__(self):
518
+ if not hasattr(F, "scaled_dot_product_attention"):
519
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0 or later. Please upgrade your PyTorch version.")
520
+
521
+ def __call__(
522
+ self,
523
+ attn: "Flux2ParallelSelfAttention",
524
+ hidden_states: torch.Tensor,
525
+ attention_mask: Optional[torch.Tensor] = None,
526
+ image_rotary_emb: Optional[torch.Tensor] = None,
527
+ ) -> torch.Tensor:
528
+ # Parallel in (QKV + MLP in) projection
529
+ hidden_states = attn.to_qkv_mlp_proj(hidden_states)
530
+ qkv, mlp_hidden_states = torch.split(
531
+ hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
532
+ )
533
+
534
+ # Handle the attention logic
535
+ query, key, value = qkv.chunk(3, dim=-1)
536
+
537
+ query = query.unflatten(-1, (attn.heads, -1))
538
+ key = key.unflatten(-1, (attn.heads, -1))
539
+ value = value.unflatten(-1, (attn.heads, -1))
540
+
541
+ query = attn.norm_q(query)
542
+ key = attn.norm_k(key)
543
+
544
+ if image_rotary_emb is not None:
545
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
546
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
547
+
548
+ hidden_states = attention_ops.attention(
549
+ query,
550
+ key,
551
+ value,
552
+ q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
553
+ )
554
+ hidden_states = hidden_states.flatten(2, 3)
555
+ hidden_states = hidden_states.to(query.dtype)
556
+
557
+ # Handle the feedforward (FF) logic
558
+ mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
559
+
560
+ # Concatenate and parallel output projection
561
+ hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
562
+ hidden_states = attn.to_out(hidden_states)
563
+
564
+ return hidden_states
565
+
566
+
567
+ class Flux2ParallelSelfAttention(torch.nn.Module):
568
+ """
569
+ Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
570
+
571
+ This implements a parallel transformer block, where the attention QKV projections are fused with the feedforward (FF)
572
+ input projections, and the attention output projection is fused with the FF output projection. See the [ViT-22B
573
+ paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
574
+ """
575
+
576
+ _default_processor_cls = Flux2ParallelSelfAttnProcessor
577
+ _available_processors = [Flux2ParallelSelfAttnProcessor]
578
+ # Does not support QKV fusion as the QKV projections are always fused
579
+ _supports_qkv_fusion = False
580
+
581
+ def __init__(
582
+ self,
583
+ query_dim: int,
584
+ heads: int = 8,
585
+ dim_head: int = 64,
586
+ dropout: float = 0.0,
587
+ bias: bool = False,
588
+ out_bias: bool = True,
589
+ eps: float = 1e-5,
590
+ out_dim: Optional[int] = None,
591
+ elementwise_affine: bool = True,
592
+ mlp_ratio: float = 4.0,
593
+ mlp_mult_factor: int = 2,
594
+ processor=None,
595
+ ):
596
+ super().__init__()
597
+
598
+ self.head_dim = dim_head
599
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
600
+ self.query_dim = query_dim
601
+ self.out_dim = out_dim if out_dim is not None else query_dim
602
+ self.heads = out_dim // dim_head if out_dim is not None else heads
603
+
604
+ self.use_bias = bias
605
+ self.dropout = dropout
606
+
607
+ self.mlp_ratio = mlp_ratio
608
+ self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
609
+ self.mlp_mult_factor = mlp_mult_factor
610
+
611
+ # Fused QKV projections + MLP input projection
612
+ self.to_qkv_mlp_proj = torch.nn.Linear(
613
+ self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias
614
+ )
615
+ self.mlp_act_fn = Flux2SwiGLU()
616
+
617
+ # QK Norm
618
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
619
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
620
+
621
+ # Fused attention output projection + MLP output projection
622
+ self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)
623
+
624
+ if processor is None:
625
+ processor = self._default_processor_cls()
626
+ self.processor = processor
627
+
628
+ def forward(
629
+ self,
630
+ hidden_states: torch.Tensor,
631
+ attention_mask: Optional[torch.Tensor] = None,
632
+ image_rotary_emb: Optional[torch.Tensor] = None,
633
+ **kwargs,
634
+ ) -> torch.Tensor:
635
+ return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
636
+
637
+
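A sketch of the fused-projection split performed by the processor above (illustrative, with small dimensions for brevity; the real single-stream blocks use query_dim=3072, heads=24, dim_head=128):

    attn = Flux2ParallelSelfAttention(query_dim=256, heads=4, dim_head=64, out_dim=256, mlp_ratio=3.0)
    x = torch.randn(1, 16, 256)
    fused = attn.to_qkv_mlp_proj(x)    # last dim = 3*inner_dim + mlp_hidden_dim*mlp_mult_factor = 768 + 1536
    qkv, mlp = torch.split(fused, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1)
    assert qkv.shape[-1] == 768 and mlp.shape[-1] == 1536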
638
+ class Flux2SingleTransformerBlock(nn.Module):
639
+ def __init__(
640
+ self,
641
+ dim: int,
642
+ num_attention_heads: int,
643
+ attention_head_dim: int,
644
+ mlp_ratio: float = 3.0,
645
+ eps: float = 1e-6,
646
+ bias: bool = False,
647
+ ):
648
+ super().__init__()
649
+
650
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
651
+
652
+ # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this
653
+ # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)
654
+ # for a visual depiction of this type of transformer block.
655
+ self.attn = Flux2ParallelSelfAttention(
656
+ query_dim=dim,
657
+ dim_head=attention_head_dim,
658
+ heads=num_attention_heads,
659
+ out_dim=dim,
660
+ bias=bias,
661
+ out_bias=bias,
662
+ eps=eps,
663
+ mlp_ratio=mlp_ratio,
664
+ mlp_mult_factor=2,
665
+ processor=Flux2ParallelSelfAttnProcessor(),
666
+ )
667
+
668
+ def forward(
669
+ self,
670
+ hidden_states: torch.Tensor,
671
+ encoder_hidden_states: Optional[torch.Tensor],
672
+ temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
673
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
674
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
675
+ split_hidden_states: bool = False,
676
+ text_seq_len: Optional[int] = None,
677
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
678
+ # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
679
+ # concatenated
680
+ if encoder_hidden_states is not None:
681
+ text_seq_len = encoder_hidden_states.shape[1]
682
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
683
+
684
+ mod_shift, mod_scale, mod_gate = temb_mod_params
685
+
686
+ norm_hidden_states = self.norm(hidden_states)
687
+ norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
688
+
689
+ joint_attention_kwargs = joint_attention_kwargs or {}
690
+ attn_output = self.attn(
691
+ hidden_states=norm_hidden_states,
692
+ image_rotary_emb=image_rotary_emb,
693
+ **joint_attention_kwargs,
694
+ )
695
+
696
+ hidden_states = hidden_states + mod_gate * attn_output
697
+ if hidden_states.dtype == torch.float16:
698
+ hidden_states = hidden_states.clip(-65504, 65504)
699
+
700
+ if split_hidden_states:
701
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
702
+ return encoder_hidden_states, hidden_states
703
+ else:
704
+ return hidden_states
705
+
706
+
707
+ class Flux2TransformerBlock(nn.Module):
708
+ def __init__(
709
+ self,
710
+ dim: int,
711
+ num_attention_heads: int,
712
+ attention_head_dim: int,
713
+ mlp_ratio: float = 3.0,
714
+ eps: float = 1e-6,
715
+ bias: bool = False,
716
+ ):
717
+ super().__init__()
718
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
719
+
720
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
721
+ self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
722
+
723
+ self.attn = Flux2Attention(
724
+ query_dim=dim,
725
+ added_kv_proj_dim=dim,
726
+ dim_head=attention_head_dim,
727
+ heads=num_attention_heads,
728
+ out_dim=dim,
729
+ bias=bias,
730
+ added_proj_bias=bias,
731
+ out_bias=bias,
732
+ eps=eps,
733
+ processor=Flux2AttnProcessor(),
734
+ )
735
+
736
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
737
+ self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
738
+
739
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
740
+ self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
741
+
742
+ def forward(
743
+ self,
744
+ hidden_states: torch.Tensor,
745
+ encoder_hidden_states: torch.Tensor,
746
+ temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
747
+ temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
748
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
749
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
750
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
751
+ joint_attention_kwargs = joint_attention_kwargs or {}
752
+
753
+ # Modulation parameters shape: [1, 1, self.dim]
754
+ (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
755
+ (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
756
+
757
+ # Img stream
758
+ norm_hidden_states = self.norm1(hidden_states)
759
+ norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa
760
+
761
+ # Conditioning txt stream
762
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
763
+ norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa
764
+
765
+ # Attention on concatenated img + txt stream
766
+ attention_outputs = self.attn(
767
+ hidden_states=norm_hidden_states,
768
+ encoder_hidden_states=norm_encoder_hidden_states,
769
+ image_rotary_emb=image_rotary_emb,
770
+ **joint_attention_kwargs,
771
+ )
772
+
773
+ attn_output, context_attn_output = attention_outputs
774
+
775
+ # Process attention outputs for the image stream (`hidden_states`).
776
+ attn_output = gate_msa * attn_output
777
+ hidden_states = hidden_states + attn_output
778
+
779
+ norm_hidden_states = self.norm2(hidden_states)
780
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
781
+
782
+ ff_output = self.ff(norm_hidden_states)
783
+ hidden_states = hidden_states + gate_mlp * ff_output
784
+
785
+ # Process attention outputs for the text stream (`encoder_hidden_states`).
786
+ context_attn_output = c_gate_msa * context_attn_output
787
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
788
+
789
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
790
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
791
+
792
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
793
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
794
+ if encoder_hidden_states.dtype == torch.float16:
795
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
796
+
797
+ return encoder_hidden_states, hidden_states
798
+
799
+
800
+ class Flux2PosEmbed(nn.Module):
801
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
802
+ def __init__(self, theta: int, axes_dim: List[int]):
803
+ super().__init__()
804
+ self.theta = theta
805
+ self.axes_dim = axes_dim
806
+
807
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
808
+ # Expected ids shape: [S, len(self.axes_dim)]
809
+ cos_out = []
810
+ sin_out = []
811
+ pos = ids.float()
812
+ is_mps = ids.device.type == "mps"
813
+ is_npu = ids.device.type == "npu"
814
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
815
+ # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
816
+ for i in range(len(self.axes_dim)):
817
+ cos, sin = get_1d_rotary_pos_embed(
818
+ self.axes_dim[i],
819
+ pos[..., i],
820
+ theta=self.theta,
821
+ repeat_interleave_real=True,
822
+ use_real=True,
823
+ freqs_dtype=freqs_dtype,
824
+ )
825
+ cos_out.append(cos)
826
+ sin_out.append(sin)
827
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
828
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
829
+ return freqs_cos, freqs_sin
830
+
831
+
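A sketch of the rotary table construction (illustrative only; axis semantics depend on how the pipeline builds img_ids/txt_ids):

    pos_embed = Flux2PosEmbed(theta=2000, axes_dim=[32, 32, 32, 32])
    ids = torch.zeros(16, 4)           # [S, len(axes_dim)]
    cos, sin = pos_embed(ids)
    assert cos.shape == (16, 128) and sin.shape == (16, 128)  # last dim = sum(axes_dim)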
832
+ class Flux2TimestepGuidanceEmbeddings(nn.Module):
833
+ def __init__(
834
+ self,
835
+ in_channels: int = 256,
836
+ embedding_dim: int = 6144,
837
+ bias: bool = False,
838
+ guidance_embeds: bool = True,
839
+ ):
840
+ super().__init__()
841
+
842
+ self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
843
+ self.timestep_embedder = TimestepEmbedding(
844
+ in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
845
+ )
846
+
847
+ if guidance_embeds:
848
+ self.guidance_embedder = TimestepEmbedding(
849
+ in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
850
+ )
851
+ else:
852
+ self.guidance_embedder = None
853
+
854
+ def forward(self, timestep: torch.Tensor, guidance: Optional[torch.Tensor]) -> torch.Tensor:
855
+ timesteps_proj = self.time_proj(timestep)
856
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
857
+
858
+ if guidance is not None and self.guidance_embedder is not None:
859
+ guidance_proj = self.time_proj(guidance)
860
+ guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
861
+ time_guidance_emb = timesteps_emb + guidance_emb
862
+ return time_guidance_emb
863
+ else:
864
+ return timesteps_emb
865
+
866
+
867
+ class Flux2Modulation(nn.Module):
868
+ def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
869
+ super().__init__()
870
+ self.mod_param_sets = mod_param_sets
871
+
872
+ self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
873
+ self.act_fn = nn.SiLU()
874
+
875
+ def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
876
+ mod = self.act_fn(temb)
877
+ mod = self.linear(mod)
878
+
879
+ if mod.ndim == 2:
880
+ mod = mod.unsqueeze(1)
881
+ mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
882
+ # Return tuple of 3-tuples of modulation params shift/scale/gate
883
+ return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
884
+
885
+
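A sketch of the returned structure (illustrative, with a small dim): mod_param_sets=2 yields two (shift, scale, gate) triples, each broadcastable over the sequence dimension, which is exactly what Flux2TransformerBlock unpacks.

    mod = Flux2Modulation(dim=256, mod_param_sets=2)
    temb = torch.randn(2, 256)
    (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = mod(temb)
    assert shift_msa.shape == (2, 1, 256)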
886
+ class Flux2DiT(PreTrainedModel):
887
+ _supports_gradient_checkpointing = True
888
+ _no_split_modules = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
889
+
890
+ def __init__(
891
+ self,
892
+ patch_size: int = 1,
893
+ in_channels: int = 128,
894
+ out_channels: Optional[int] = None,
895
+ num_layers: int = 5,
896
+ num_single_layers: int = 20,
897
+ attention_head_dim: int = 128,
898
+ num_attention_heads: int = 24,
899
+ joint_attention_dim: int = 7680,
900
+ timestep_guidance_channels: int = 256,
901
+ mlp_ratio: float = 3.0,
902
+ axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
903
+ rope_theta: int = 2000,
904
+ eps: float = 1e-6,
905
+ guidance_embeds: bool = False,
906
+ device: str = "cuda:0",
907
+ dtype: torch.dtype = torch.float32,
908
+ ):
909
+ super().__init__()
910
+ self.out_channels = out_channels or in_channels
911
+ self.inner_dim = num_attention_heads * attention_head_dim # 24 * 128 = 3072
912
+
913
+ # 1. Sinusoidal positional embedding for RoPE on image and text tokens
914
+ self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)
915
+
916
+ # 2. Combined timestep + guidance embedding
917
+ self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
918
+ in_channels=timestep_guidance_channels,
919
+ embedding_dim=self.inner_dim,
920
+ bias=False,
921
+ guidance_embeds=guidance_embeds,
922
+ )
923
+
924
+ # 3. Modulation (all double-stream blocks share one set of modulation parameters, and all single-stream blocks share another)
925
+ # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
926
+ self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
927
+ self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
928
+ # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
929
+ self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
930
+
931
+ # 4. Input projections
932
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
933
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
934
+
935
+ # 5. Double Stream Transformer Blocks
936
+ self.transformer_blocks = nn.ModuleList(
937
+ [
938
+ Flux2TransformerBlock(
939
+ dim=self.inner_dim,
940
+ num_attention_heads=num_attention_heads,
941
+ attention_head_dim=attention_head_dim,
942
+ mlp_ratio=mlp_ratio,
943
+ eps=eps,
944
+ bias=False,
945
+ )
946
+ for _ in range(num_layers)
947
+ ]
948
+ )
949
+
950
+ # 6. Single Stream Transformer Blocks
951
+ self.single_transformer_blocks = nn.ModuleList(
952
+ [
953
+ Flux2SingleTransformerBlock(
954
+ dim=self.inner_dim,
955
+ num_attention_heads=num_attention_heads,
956
+ attention_head_dim=attention_head_dim,
957
+ mlp_ratio=mlp_ratio,
958
+ eps=eps,
959
+ bias=False,
960
+ )
961
+ for _ in range(num_single_layers)
962
+ ]
963
+ )
964
+
965
+ # 7. Output layers
966
+ self.norm_out = AdaLayerNormContinuous(
967
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
968
+ )
969
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
970
+
971
+ self.gradient_checkpointing = False
972
+
973
+ def forward(
974
+ self,
975
+ hidden_states: torch.Tensor,
976
+ encoder_hidden_states: torch.Tensor = None,
977
+ timestep: torch.LongTensor = None,
978
+ img_ids: torch.Tensor = None,
979
+ txt_ids: torch.Tensor = None,
980
+ guidance: torch.Tensor = None,
981
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
982
+ use_gradient_checkpointing=False,
983
+ use_gradient_checkpointing_offload=False,
984
+ ):
985
+ # 0. Handle input arguments
986
+ if joint_attention_kwargs is not None:
987
+ joint_attention_kwargs = joint_attention_kwargs.copy()
988
+
989
+ num_txt_tokens = encoder_hidden_states.shape[1]
990
+
991
+ # 1. Calculate timestep embedding and modulation parameters
992
+ timestep = timestep.to(hidden_states.dtype) * 1000
993
+
994
+ if guidance is not None:
995
+ guidance = guidance.to(hidden_states.dtype) * 1000
996
+
997
+ temb = self.time_guidance_embed(timestep, guidance)
998
+
999
+ double_stream_mod_img = self.double_stream_modulation_img(temb)
1000
+ double_stream_mod_txt = self.double_stream_modulation_txt(temb)
1001
+ single_stream_mod = self.single_stream_modulation(temb)[0]
1002
+
1003
+ # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
1004
+ hidden_states = self.x_embedder(hidden_states)
1005
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
1006
+
1007
+ # 3. Calculate RoPE embeddings from image and text tokens
1008
+ # NOTE: the logic below means that we can't support batched inference with images of different resolutions or
1009
+ # text prompts of different lengths. Is this a use case we want to support?
1010
+ if img_ids.ndim == 3:
1011
+ img_ids = img_ids[0]
1012
+ if txt_ids.ndim == 3:
1013
+ txt_ids = txt_ids[0]
1014
+
1015
+ image_rotary_emb = self.pos_embed(img_ids)
1016
+ text_rotary_emb = self.pos_embed(txt_ids)
1017
+ concat_rotary_emb = (
1018
+ torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
1019
+ torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
1020
+ )
1021
+
1022
+ # 4. Double Stream Transformer Blocks
1023
+ for index_block, block in enumerate(self.transformer_blocks):
1024
+ encoder_hidden_states, hidden_states = block(
1025
+ hidden_states=hidden_states,
1026
+ encoder_hidden_states=encoder_hidden_states,
1027
+ temb_mod_params_img=double_stream_mod_img,
1028
+ temb_mod_params_txt=double_stream_mod_txt,
1029
+ image_rotary_emb=concat_rotary_emb,
1030
+ joint_attention_kwargs=joint_attention_kwargs,
1031
+ )
1032
+ # Concatenate text and image streams for single-block inference
1033
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1034
+
1035
+ # 5. Single Stream Transformer Blocks
1036
+ for index_block, block in enumerate(self.single_transformer_blocks):
1037
+ hidden_states = block(
1038
+ hidden_states=hidden_states,
1039
+ encoder_hidden_states=None,
1040
+ temb_mod_params=single_stream_mod,
1041
+ image_rotary_emb=concat_rotary_emb,
1042
+ joint_attention_kwargs=joint_attention_kwargs,
1043
+ )
1044
+ # Remove text tokens from concatenated stream
1045
+ hidden_states = hidden_states[:, num_txt_tokens:, ...]
1046
+
1047
+ # 6. Output layers
1048
+ hidden_states = self.norm_out(hidden_states, temb)
1049
+ output = self.proj_out(hidden_states)
1050
+
1051
+ return output
1052
+
1053
+ @classmethod
1054
+ def from_state_dict(
1055
+ cls,
1056
+ state_dict: Dict[str, torch.Tensor],
1057
+ device: str = "cuda:0",
1058
+ dtype: torch.dtype = torch.float32,
1059
+ **kwargs,
1060
+ ) -> "Flux2DiT":
1061
+ model = cls(device="meta", dtype=dtype, **kwargs)
1062
+ model = model.requires_grad_(False)
1063
+ model.load_state_dict(state_dict, assign=True)
1064
+ model.to(device=device, dtype=dtype, non_blocking=True)
1065
+ return model
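A minimal loading sketch (illustrative; the checkpoint file name is hypothetical, and the state dict keys are assumed to already match this module's layout):

    from safetensors.torch import load_file

    state_dict = load_file("flux2_dit.safetensors")  # hypothetical path
    dit = Flux2DiT.from_state_dict(state_dict, device="cuda:0", dtype=torch.bfloat16)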