diffsynth-engine 0.7.0__py3-none-any.whl → 0.7.1.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +6 -0
- diffsynth_engine/conf/models/flux2/qwen3_8B_config.json +68 -0
- diffsynth_engine/configs/__init__.py +4 -0
- diffsynth_engine/configs/pipeline.py +50 -1
- diffsynth_engine/models/flux2/__init__.py +7 -0
- diffsynth_engine/models/flux2/flux2_dit.py +1065 -0
- diffsynth_engine/models/flux2/flux2_vae.py +1992 -0
- diffsynth_engine/pipelines/__init__.py +2 -0
- diffsynth_engine/pipelines/flux2_klein_image.py +634 -0
- diffsynth_engine/utils/constants.py +1 -0
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/RECORD +15 -10
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/WHEEL +1 -1
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/top_level.txt +0 -0
diffsynth_engine/models/flux2/flux2_dit.py
@@ -0,0 +1,1065 @@
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from diffsynth_engine.models.base import PreTrainedModel
from diffsynth_engine.models.basic.transformer_helper import RMSNorm
from diffsynth_engine.models.basic import attention as attention_ops
from diffsynth_engine.utils.gguf import gguf_inference
from diffsynth_engine.utils.fp8_linear import fp8_inference
from diffsynth_engine.utils.parallel import (
    cfg_parallel,
    cfg_parallel_unshard,
    sequence_parallel,
    sequence_parallel_unshard,
)


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings. This matches the implementation in Denoising Diffusion Probabilistic Models.

    Args:
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False).
        downscale_freq_shift (float):
            Controls the delta between frequencies across embedding dimensions.
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings.
    Returns:
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


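# Illustrative shape sketch for the helper above (not part of the upstream file; the values are
# assumptions chosen for the example). With flip_sin_to_cos=True and downscale_freq_shift=0,
# which is how the Timesteps / Flux2TimestepGuidanceEmbeddings modules below call it:
#
#   ts = torch.tensor([0.0, 250.5, 999.0])                       # fractional timesteps are fine
#   emb = get_timestep_embedding(ts, embedding_dim=256,
#                                flip_sin_to_cos=True, downscale_freq_shift=0)
#   emb.shape                                                     # torch.Size([3, 256])
#
# The first 128 columns are then the cos terms and the last 128 columns the sin terms.
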
class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = torch.nn.SiLU()

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
        return t_emb


class AdaLayerNormContinuous(nn.Module):
    r"""
    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).

    Args:
        embedding_dim (`int`): Embedding dimension to use during projection.
        conditioning_embedding_dim (`int`): Dimension of the input condition.
        elementwise_affine (`bool`, defaults to `True`):
            Boolean flag to denote if affine transformation should be applied.
        eps (`float`, defaults to 1e-5): Epsilon factor.
        bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
        norm_type (`str`, defaults to `"layer_norm"`):
            Normalization layer to use. Values supported: "layer_norm", "rms_norm".
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
    data type.

    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (`bool`, *optional*):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the context extrapolation. Defaults to 1.0.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
            Otherwise, they are concatenated with themselves.
        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
            the dtype of the frequency tensor.
    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore  # [S]

    theta = theta * ntk_factor
    freqs = (
        1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
    )  # [D/2]
    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
    is_npu = freqs.device.type == "npu"
    if is_npu:
        freqs = freqs.float()
    if use_real and repeat_interleave_real:
        # flux, hunyuan-dit, cogvideox
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # stable audio, allegro
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
    sequence_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings. [B, H, S, D]
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        if sequence_dim == 2:
            cos = cos[None, None, :, :]
            sin = sin[None, None, :, :]
        elif sequence_dim == 1:
            cos = cos[None, :, None, :]
            sin = sin[None, :, None, :]
        else:
            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")

        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)


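# Illustrative sketch of the two RoPE helpers above (not part of the upstream file; shapes and
# values are assumptions for the example). For one rotary axis of width 32 and 10 positions,
# get_1d_rotary_pos_embed returns interleaved cos/sin tables, and apply_rotary_emb rotates a
# [B, S, H, D] tensor when sequence_dim=1 (the layout used by the attention processors below):
#
#   cos, sin = get_1d_rotary_pos_embed(32, torch.arange(10, dtype=torch.float32), theta=2000.0,
#                                      use_real=True, repeat_interleave_real=True)
#   cos.shape, sin.shape        # (torch.Size([10, 32]), torch.Size([10, 32]))
#
#   q = torch.randn(1, 10, 24, 32)                        # [batch, seq, heads, head_dim]
#   q_rot = apply_rotary_emb(q, (cos, sin), sequence_dim=1)
#   q_rot.shape                                           # torch.Size([1, 10, 24, 32])
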
def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)

    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

    return query, key, value, encoder_query, encoder_key, encoder_value


def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)

    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
        encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)

    return query, key, value, encoder_query, encoder_key, encoder_value


def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
    return _get_projections(attn, hidden_states, encoder_hidden_states)


class Flux2SwiGLU(nn.Module):
    """
    Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection
    layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters.
    """

    def __init__(self):
        super().__init__()
        self.gate_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        x = self.gate_fn(x1) * x2
        return x


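# Illustrative sketch (not part of the upstream file): Flux2SwiGLU halves the last dimension,
# so the Flux2FeedForward module below projects to inner_dim * 2 in linear_in and the activation
# brings it back to inner_dim before linear_out:
#
#   act = Flux2SwiGLU()
#   act(torch.randn(2, 7, 128)).shape       # torch.Size([2, 7, 64]) == silu(x1) * x2
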
class Flux2FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: float = 3.0,
        inner_dim: Optional[int] = None,
        bias: bool = False,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out or dim

        # Flux2SwiGLU will reduce the dimension by half
        self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias)
        self.act_fn = Flux2SwiGLU()
        self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear_in(x)
        x = self.act_fn(x)
        x = self.linear_out(x)
        return x


class Flux2AttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your PyTorch version.")

    def __call__(
        self,
        attn: "Flux2Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = attention_ops.attention(
            query,
            key,
            value,
            q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states


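# Sketch of the token layout handled by Flux2AttnProcessor above (an illustration, not part of
# the upstream file). When encoder_hidden_states (text) is passed, the T text tokens are
# prepended to the N image tokens before attention and split off again afterwards:
#
#   joint sequence = [txt_0 ... txt_{T-1}, img_0 ... img_{N-1}]   # length T + N
#   attention output[:, :T]  -> projected with to_add_out  -> new encoder_hidden_states
#   attention output[:, T:]  -> projected with to_out      -> new hidden_states
#
# This is why the RoPE tables passed in must also be ordered text-first (see concat_rotary_emb
# in Flux2DiT.forward below).
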
class Flux2Attention(torch.nn.Module):
    _default_processor_cls = Flux2AttnProcessor
    _available_processors = [Flux2AttnProcessor]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        elementwise_affine: bool = True,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.heads = out_dim // dim_head if out_dim is not None else heads

        self.use_bias = bias
        self.dropout = dropout

        self.added_kv_proj_dim = added_kv_proj_dim
        self.added_proj_bias = added_proj_bias

        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

        # QK Norm
        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)

        self.to_out = torch.nn.ModuleList([])
        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(torch.nn.Dropout(dropout))

        if added_kv_proj_dim is not None:
            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)

        if processor is None:
            processor = self._default_processor_cls()
        self.processor = processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)


class Flux2ParallelSelfAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your PyTorch version.")

    def __call__(
        self,
        attn: "Flux2ParallelSelfAttention",
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Parallel in (QKV + MLP in) projection
        hidden_states = attn.to_qkv_mlp_proj(hidden_states)
        qkv, mlp_hidden_states = torch.split(
            hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
        )

        # Handle the attention logic
        query, key, value = qkv.chunk(3, dim=-1)

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = attention_ops.attention(
            query,
            key,
            value,
            q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        # Handle the feedforward (FF) logic
        mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)

        # Concatenate and parallel output projection
        hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
        hidden_states = attn.to_out(hidden_states)

        return hidden_states


class Flux2ParallelSelfAttention(torch.nn.Module):
    """
    Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.

    This implements a parallel transformer block, where the attention QKV projections are fused with the feedforward
    (FF) input projections, and the attention output projections are fused with the FF output projections. See the
    [ViT-22B paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
    """

    _default_processor_cls = Flux2ParallelSelfAttnProcessor
    _available_processors = [Flux2ParallelSelfAttnProcessor]
    # Does not support QKV fusion as the QKV projections are always fused
    _supports_qkv_fusion = False

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        elementwise_affine: bool = True,
        mlp_ratio: float = 4.0,
        mlp_mult_factor: int = 2,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.heads = out_dim // dim_head if out_dim is not None else heads

        self.use_bias = bias
        self.dropout = dropout

        self.mlp_ratio = mlp_ratio
        self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
        self.mlp_mult_factor = mlp_mult_factor

        # Fused QKV projections + MLP input projection
        self.to_qkv_mlp_proj = torch.nn.Linear(
            self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias
        )
        self.mlp_act_fn = Flux2SwiGLU()

        # QK Norm
        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)

        # Fused attention output projection + MLP output projection
        self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)

        if processor is None:
            processor = self._default_processor_cls()
        self.processor = processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)


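# Illustrative sketch of the fused projection in Flux2ParallelSelfAttention above (not part of
# the upstream file; the numbers assume Flux2DiT's defaults: dim=3072, heads=24, head_dim=128,
# mlp_ratio=3.0, mlp_mult_factor=2):
#
#   attn = Flux2ParallelSelfAttention(query_dim=3072, heads=24, dim_head=128, out_dim=3072,
#                                     mlp_ratio=3.0, mlp_mult_factor=2)
#   attn.to_qkv_mlp_proj.out_features   # 3 * 3072 + 9216 * 2 = 27648
#   attn.to_out.in_features             # 3072 + 9216 = 12288
#
# i.e. one matmul produces Q, K, V and the SwiGLU input, and one matmul merges the attention
# output with the MLP output.
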
class Flux2SingleTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 3.0,
        eps: float = 1e-6,
        bias: bool = False,
    ):
        super().__init__()

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

        # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this
        # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)
        # for a visual depiction of this type of transformer block.
        self.attn = Flux2ParallelSelfAttention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=bias,
            out_bias=bias,
            eps=eps,
            mlp_ratio=mlp_ratio,
            mlp_mult_factor=2,
            processor=Flux2ParallelSelfAttnProcessor(),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
        temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        split_hidden_states: bool = False,
        text_seq_len: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
        # concatenated
        if encoder_hidden_states is not None:
            text_seq_len = encoder_hidden_states.shape[1]
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        mod_shift, mod_scale, mod_gate = temb_mod_params

        norm_hidden_states = self.norm(hidden_states)
        norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift

        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        hidden_states = hidden_states + mod_gate * attn_output
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        if split_hidden_states:
            encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
            return encoder_hidden_states, hidden_states
        else:
            return hidden_states


class Flux2TransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 3.0,
        eps: float = 1e-6,
        bias: bool = False,
    ):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

        self.attn = Flux2Attention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=bias,
            added_proj_bias=bias,
            out_bias=bias,
            eps=eps,
            processor=Flux2AttnProcessor(),
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)

        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
        temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        joint_attention_kwargs = joint_attention_kwargs or {}

        # Modulation parameters shape: [1, 1, self.dim]
        (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
        (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt

        # Img stream
        norm_hidden_states = self.norm1(hidden_states)
        norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa

        # Conditioning txt stream
        norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
        norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa

        # Attention on concatenated img + txt stream
        attention_outputs = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        attn_output, context_attn_output = attention_outputs

        # Process attention outputs for the image stream (`hidden_states`).
        attn_output = gate_msa * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + gate_mlp * ff_output

        # Process attention outputs for the text stream (`encoder_hidden_states`).
        context_attn_output = c_gate_msa * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states


class Flux2PosEmbed(nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # Expected ids shape: [S, len(self.axes_dim)]
        cos_out = []
        sin_out = []
        pos = ids.float()
        is_mps = ids.device.type == "mps"
        is_npu = ids.device.type == "npu"
        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
        for i in range(len(self.axes_dim)):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[..., i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin


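# Illustrative sketch for Flux2PosEmbed above (not part of the upstream file; the grid size is
# an assumption). With axes_dim=(32, 32, 32, 32) and theta=2000 (the Flux2DiT defaults), ids
# carry one coordinate per RoPE axis and each axis contributes 32 of the 128 rotary channels:
#
#   pos_embed = Flux2PosEmbed(theta=2000, axes_dim=[32, 32, 32, 32])
#   ids = torch.zeros(4096, 4)              # e.g. a 64x64 latent grid flattened to 4096 tokens
#   cos, sin = pos_embed(ids)
#   cos.shape, sin.shape                    # (torch.Size([4096, 128]), torch.Size([4096, 128]))
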
class Flux2TimestepGuidanceEmbeddings(nn.Module):
    def __init__(
        self,
        in_channels: int = 256,
        embedding_dim: int = 6144,
        bias: bool = False,
        guidance_embeds: bool = True,
    ):
        super().__init__()

        self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(
            in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
        )

        if guidance_embeds:
            self.guidance_embedder = TimestepEmbedding(
                in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
            )
        else:
            self.guidance_embedder = None

    def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype))  # (N, D)

        if guidance is not None and self.guidance_embedder is not None:
            guidance_proj = self.time_proj(guidance)
            guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype))  # (N, D)
            time_guidance_emb = timesteps_emb + guidance_emb
            return time_guidance_emb
        else:
            return timesteps_emb


class Flux2Modulation(nn.Module):
    def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
        super().__init__()
        self.mod_param_sets = mod_param_sets

        self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
        self.act_fn = nn.SiLU()

    def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
        mod = self.act_fn(temb)
        mod = self.linear(mod)

        if mod.ndim == 2:
            mod = mod.unsqueeze(1)
        mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
        # Return tuple of 3-tuples of modulation params shift/scale/gate
        return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))


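# Illustrative sketch for Flux2Modulation above (not part of the upstream file; dim=3072 is the
# Flux2DiT default). Each "set" is a (shift, scale, gate) triple broadcast over the sequence:
#
#   mod = Flux2Modulation(dim=3072, mod_param_sets=2)
#   temb = torch.randn(2, 3072)                                   # (batch, inner_dim)
#   (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = mod(temb)
#   shift_msa.shape                                               # torch.Size([2, 1, 3072])
#
# Flux2TransformerBlock consumes two such sets per stream (attention + FF), while the single
# stream blocks share one set between their parallel attention and MLP paths.
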
class Flux2DiT(PreTrainedModel):
    _supports_gradient_checkpointing = True
    _no_split_modules = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]

    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 128,
        out_channels: Optional[int] = None,
        num_layers: int = 5,
        num_single_layers: int = 20,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 7680,
        timestep_guidance_channels: int = 256,
        mlp_ratio: float = 3.0,
        axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
        rope_theta: int = 2000,
        eps: float = 1e-6,
        guidance_embeds: bool = False,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim  # 24 * 128 = 3072

        # 1. Sinusoidal positional embedding for RoPE on image and text tokens
        self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)

        # 2. Combined timestep + guidance embedding
        self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
            in_channels=timestep_guidance_channels,
            embedding_dim=self.inner_dim,
            bias=False,
            guidance_embeds=guidance_embeds,
        )

        # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
        # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
        self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
        self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
        # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
        self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)

        # 4. Input projections
        self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)

        # 5. Double Stream Transformer Blocks
        self.transformer_blocks = nn.ModuleList(
            [
                Flux2TransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    eps=eps,
                    bias=False,
                )
                for _ in range(num_layers)
            ]
        )

        # 6. Single Stream Transformer Blocks
        self.single_transformer_blocks = nn.ModuleList(
            [
                Flux2SingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    eps=eps,
                    bias=False,
                )
                for _ in range(num_single_layers)
            ]
        )

        # 7. Output layers
        self.norm_out = AdaLayerNormContinuous(
            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
        )
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        # 0. Handle input arguments
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()

        num_txt_tokens = encoder_hidden_states.shape[1]

        # 1. Calculate timestep embedding and modulation parameters
        timestep = timestep.to(hidden_states.dtype) * 1000

        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000

        temb = self.time_guidance_embed(timestep, guidance)

        double_stream_mod_img = self.double_stream_modulation_img(temb)
        double_stream_mod_txt = self.double_stream_modulation_txt(temb)
        single_stream_mod = self.single_stream_modulation(temb)[0]

        # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
        hidden_states = self.x_embedder(hidden_states)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        # 3. Calculate RoPE embeddings from image and text tokens
        # NOTE: the below logic means that we can't support batched inference with images of different resolutions or
        # text prompts of different lengths. Is this a use case we want to support?
        if img_ids.ndim == 3:
            img_ids = img_ids[0]
        if txt_ids.ndim == 3:
            txt_ids = txt_ids[0]

        image_rotary_emb = self.pos_embed(img_ids)
        text_rotary_emb = self.pos_embed(txt_ids)
        concat_rotary_emb = (
            torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
            torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
        )

        # 4. Double Stream Transformer Blocks
        for index_block, block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb_mod_params_img=double_stream_mod_img,
                temb_mod_params_txt=double_stream_mod_txt,
                image_rotary_emb=concat_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )
        # Concatenate text and image streams for single-block inference
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # 5. Single Stream Transformer Blocks
        for index_block, block in enumerate(self.single_transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=None,
                temb_mod_params=single_stream_mod,
                image_rotary_emb=concat_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )
        # Remove text tokens from concatenated stream
        hidden_states = hidden_states[:, num_txt_tokens:, ...]

        # 6. Output layers
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        return output

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        device: str = "cuda:0",
        dtype: torch.dtype = torch.float32,
        **kwargs,
    ) -> "Flux2DiT":
        model = cls(device="meta", dtype=dtype, **kwargs)
        model = model.requires_grad_(False)
        model.load_state_dict(state_dict, assign=True)
        model.to(device=device, dtype=dtype, non_blocking=True)
        return model