diffsynth-engine 0.6.1.dev21__py3-none-any.whl → 0.6.1.dev23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
- diffsynth_engine/configs/pipeline.py +35 -5
- diffsynth_engine/models/basic/attention.py +59 -20
- diffsynth_engine/models/basic/video_sparse_attention.py +235 -0
- diffsynth_engine/models/flux/flux_controlnet.py +7 -19
- diffsynth_engine/models/flux/flux_dit.py +22 -36
- diffsynth_engine/models/flux/flux_dit_fbcache.py +9 -7
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +13 -15
- diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
- diffsynth_engine/models/wan/wan_dit.py +62 -22
- diffsynth_engine/pipelines/flux_image.py +11 -10
- diffsynth_engine/pipelines/qwen_image.py +26 -28
- diffsynth_engine/pipelines/wan_s2v.py +3 -8
- diffsynth_engine/pipelines/wan_video.py +11 -13
- diffsynth_engine/utils/constants.py +13 -12
- diffsynth_engine/utils/flag.py +6 -0
- diffsynth_engine/utils/parallel.py +51 -6
- {diffsynth_engine-0.6.1.dev21.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.6.1.dev21.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/RECORD +34 -32
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
- {diffsynth_engine-0.6.1.dev21.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.6.1.dev21.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev21.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/top_level.txt +0 -0
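
The common thread in the model changes below is that attn_kwargs, the extra keyword arguments splatted into attention_ops.attention, is no longer fixed at module construction but is passed through each forward call, so attention backend options can vary per step. Here is a minimal, self-contained sketch of that pattern; it is a toy module, not one of the package's classes, and the dropout_p key is just a standard torch scaled_dot_product_attention argument used for illustration:

import torch
import torch.nn.functional as F
from torch import nn


class ToyAttention(nn.Module):
    # Illustrates the 0.6.1.dev23 pattern: attention options arrive with each call
    # instead of being stored on the module at construction time.
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.to_qkv = nn.Linear(dim, dim * 3)

    def forward(self, x: torch.Tensor, attn_kwargs: dict | None = None) -> torch.Tensor:
        b, s, d = x.shape
        qkv = self.to_qkv(x).reshape(b, s, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Caller-supplied options are applied here, per call.
        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
        out = F.scaled_dot_product_attention(q, k, v, **attn_kwargs)
        return out.transpose(1, 2).reshape(b, s, d)


block = ToyAttention(dim=64, num_heads=4)
x = torch.randn(2, 16, 64)
y = block(x, attn_kwargs={"dropout_p": 0.0})  # options chosen by the caller, not baked into the module

In the package itself the same relocation shows up as the removed attn_kwargs constructor parameters and the new attn_kwargs=None forward parameters in the diffs that follow; the keys actually accepted depend on diffsynth_engine.models.basic.attention, which this section does not show in full.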
diffsynth_engine/models/flux/flux_dit.py

@@ -176,7 +176,6 @@ class FluxDoubleAttention(nn.Module):
         dim_b,
         num_heads,
         head_dim,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -194,19 +193,20 @@ class FluxDoubleAttention(nn.Module):
 
         self.a_to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
         self.b_to_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
 
     def attention_callback(self, attn_out_a, attn_out_b, x_a, x_b, q_a, q_b, k_a, k_b, v_a, v_b, rope_emb, image_emb):
         return attn_out_a, attn_out_b
 
-    def forward(self, image, text, rope_emb, image_emb):
+    def forward(self, image, text, rope_emb, image_emb, attn_kwargs=None):
         q_a, k_a, v_a = rearrange(self.a_to_qkv(image), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q_b, k_b, v_b = rearrange(self.b_to_qkv(text), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q = torch.cat([self.norm_q_b(q_b), self.norm_q_a(q_a)], dim=1)
         k = torch.cat([self.norm_k_b(k_b), self.norm_k_a(k_a)], dim=1)
         v = torch.cat([v_b, v_a], dim=1)
         q, k = apply_rope(q, k, rope_emb)
-        attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        attn_out = attention_ops.attention(q, k, v, **attn_kwargs)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         text_out, image_out = attn_out[:, : text.shape[1]], attn_out[:, text.shape[1] :]
         image_out, text_out = self.attention_callback(
@@ -231,14 +231,11 @@ class FluxDoubleTransformerBlock(nn.Module):
         self,
         dim,
         num_heads,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
         super().__init__()
-        self.attn = FluxDoubleAttention(
-            dim, dim, num_heads, dim // num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype
-        )
+        self.attn = FluxDoubleAttention(dim, dim, num_heads, dim // num_heads, device=device, dtype=dtype)
         # Image
         self.norm_msa_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
         self.norm_mlp_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
@@ -256,11 +253,11 @@ class FluxDoubleTransformerBlock(nn.Module):
             nn.Linear(dim * 4, dim, device=device, dtype=dtype),
         )
 
-    def forward(self, image, text, t_emb, rope_emb, image_emb=None):
+    def forward(self, image, text, t_emb, rope_emb, image_emb=None, attn_kwargs=None):
         # AdaLayerNorm-Zero for Image and Text MSA
         image_in, gate_a = self.norm_msa_a(image, t_emb)
         text_in, gate_b = self.norm_msa_b(text, t_emb)
-        image_out, text_out = self.attn(image_in, text_in, rope_emb, image_emb)
+        image_out, text_out = self.attn(image_in, text_in, rope_emb, image_emb, attn_kwargs)
         image = image + gate_a * image_out
         text = text + gate_b * text_out
 
@@ -279,7 +276,6 @@ class FluxSingleAttention(nn.Module):
         self,
         dim,
         num_heads,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -288,15 +284,16 @@ class FluxSingleAttention(nn.Module):
         self.to_qkv = nn.Linear(dim, dim * 3, device=device, dtype=dtype)
         self.norm_q_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
         self.norm_k_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
 
     def attention_callback(self, attn_out, x, q, k, v, rope_emb, image_emb):
         return attn_out
 
-    def forward(self, x, rope_emb, image_emb):
+    def forward(self, x, rope_emb, image_emb, attn_kwargs=None):
         q, k, v = rearrange(self.to_qkv(x), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q, k = apply_rope(self.norm_q_a(q), self.norm_k_a(k), rope_emb)
-        attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        attn_out = attention_ops.attention(q, k, v, **attn_kwargs)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         return self.attention_callback(attn_out=attn_out, x=x, q=q, k=k, v=v, rope_emb=rope_emb, image_emb=image_emb)
 
@@ -306,23 +303,22 @@ class FluxSingleTransformerBlock(nn.Module):
         self,
         dim,
         num_heads,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
         super().__init__()
         self.dim = dim
         self.norm = AdaLayerNormZero(dim, device=device, dtype=dtype)
-        self.attn = FluxSingleAttention(dim, num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+        self.attn = FluxSingleAttention(dim, num_heads, device=device, dtype=dtype)
         self.mlp = nn.Sequential(
             nn.Linear(dim, dim * 4, device=device, dtype=dtype),
             nn.GELU(approximate="tanh"),
         )
         self.proj_out = nn.Linear(dim * 5, dim, device=device, dtype=dtype)
 
-    def forward(self, x, t_emb, rope_emb, image_emb=None):
+    def forward(self, x, t_emb, rope_emb, image_emb=None, attn_kwargs=None):
         h, gate = self.norm(x, emb=t_emb)
-        attn_output = self.attn(h, rope_emb, image_emb)
+        attn_output = self.attn(h, rope_emb, image_emb, attn_kwargs)
         mlp_output = self.mlp(h)
         return x + gate * self.proj_out(torch.cat([attn_output, mlp_output], dim=2))
 
@@ -334,7 +330,6 @@ class FluxDiT(PreTrainedModel):
     def __init__(
         self,
         in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -352,16 +347,10 @@ class FluxDiT(PreTrainedModel):
         self.x_embedder = nn.Linear(in_channel, 3072, device=device, dtype=dtype)
 
         self.blocks = nn.ModuleList(
-            [
-                FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
-                for _ in range(19)
-            ]
+            [FluxDoubleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(19)]
         )
         self.single_blocks = nn.ModuleList(
-            [
-                FluxSingleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
-                for _ in range(38)
-            ]
+            [FluxSingleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(38)]
         )
         self.final_norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
         self.final_proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
@@ -403,6 +392,7 @@ class FluxDiT(PreTrainedModel):
         text_ids: torch.Tensor,
         guidance: torch.Tensor,
         image_emb: torch.Tensor | None = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_double_block_output: List[torch.Tensor] | None = None,
         controlnet_single_block_output: List[torch.Tensor] | None = None,
         **kwargs,
@@ -470,14 +460,16 @@ class FluxDiT(PreTrainedModel):
         rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)
 
         for i, block in enumerate(self.blocks):
-            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+            hidden_states, prompt_emb = block(
+                hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+            )
             if len(controlnet_double_block_output) > 0:
                 interval_control = len(self.blocks) / len(controlnet_double_block_output)
                 interval_control = int(np.ceil(interval_control))
                 hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
         hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
         for i, block in enumerate(self.single_blocks):
-            hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+            hidden_states = block(hidden_states, conditioning, rope_emb, image_emb, attn_kwargs)
             if len(controlnet_single_block_output) > 0:
                 interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
                 interval_control = int(np.ceil(interval_control))
@@ -498,14 +490,8 @@ class FluxDiT(PreTrainedModel):
         device: str,
         dtype: torch.dtype,
         in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        model = cls(
-            device="meta",
-            dtype=dtype,
-            in_channel=in_channel,
-            attn_kwargs=attn_kwargs,
-        )
+        model = cls(device="meta", dtype=dtype, in_channel=in_channel)
         model = model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
diffsynth_engine/models/flux/flux_dit_fbcache.py

@@ -20,12 +20,11 @@ class FluxDiTFBCache(FluxDiT):
     def __init__(
         self,
         in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
         relative_l1_threshold: float = 0.05,
     ):
-        super().__init__(in_channel=in_channel, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+        super().__init__(in_channel=in_channel, device=device, dtype=dtype)
         self.relative_l1_threshold = relative_l1_threshold
         self.step_count = 0
         self.num_inference_steps = 0
@@ -56,6 +55,7 @@ class FluxDiTFBCache(FluxDiT):
         text_ids: torch.Tensor,
         guidance: torch.Tensor,
         image_emb: torch.Tensor | None = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_double_block_output: List[torch.Tensor] | None = None,
         controlnet_single_block_output: List[torch.Tensor] | None = None,
         **kwargs,
@@ -124,7 +124,9 @@ class FluxDiTFBCache(FluxDiT):
 
         # first block
         original_hidden_states = hidden_states
-        hidden_states, prompt_emb = self.blocks[0](hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+        hidden_states, prompt_emb = self.blocks[0](
+            hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+        )
         first_hidden_states_residual = hidden_states - original_hidden_states
 
         (first_hidden_states_residual,) = sequence_parallel_unshard(
@@ -149,14 +151,16 @@ class FluxDiTFBCache(FluxDiT):
 
             first_hidden_states = hidden_states.clone()
             for i, block in enumerate(self.blocks[1:]):
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+                hidden_states, prompt_emb = block(
+                    hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+                )
                 if len(controlnet_double_block_output) > 0:
                     interval_control = len(self.blocks) / len(controlnet_double_block_output)
                     interval_control = int(np.ceil(interval_control))
                     hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
             hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
             for i, block in enumerate(self.single_blocks):
-                hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+                hidden_states = block(hidden_states, conditioning, rope_emb, image_emb, attn_kwargs)
                 if len(controlnet_single_block_output) > 0:
                     interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
                     interval_control = int(np.ceil(interval_control))
@@ -182,14 +186,12 @@ class FluxDiTFBCache(FluxDiT):
         device: str,
         dtype: torch.dtype,
         in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         relative_l1_threshold: float = 0.05,
     ):
         model = cls(
             device="meta",
             dtype=dtype,
             in_channel=in_channel,
-            attn_kwargs=attn_kwargs,
             relative_l1_threshold=relative_l1_threshold,
         )
         model = model.requires_grad_(False)
diffsynth_engine/models/flux/flux_ipadapter.py

@@ -2,7 +2,7 @@ import torch
 from einops import rearrange
 from torch import nn
 from PIL import Image
-from typing import Dict, List, Optional, Any
+from typing import Dict, List
 from functools import partial
 from diffsynth_engine.models.text_encoder.siglip import SiglipImageEncoder
 from diffsynth_engine.models.basic.transformer_helper import RMSNorm
@@ -18,7 +18,6 @@ class FluxIPAdapterAttention(nn.Module):
         dim: int = 3072,
         head_num: int = 24,
         scale: float = 1.0,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -28,12 +27,13 @@ class FluxIPAdapterAttention(nn.Module):
         self.to_v_ip = nn.Linear(image_emb_dim, dim, device=device, dtype=dtype, bias=False)
         self.head_num = head_num
         self.scale = scale
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
 
-    def forward(self, query: torch.Tensor, image_emb: torch.Tensor):
+    def forward(self, query: torch.Tensor, image_emb: torch.Tensor, attn_kwargs=None):
         key = rearrange(self.norm_k(self.to_k_ip(image_emb)), "b s (h d) -> b s h d", h=self.head_num)
         value = rearrange(self.to_v_ip(image_emb), "b s (h d) -> b s h d", h=self.head_num)
-        attn_out = attention(query, key, value, **self.attn_kwargs)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        attn_out = attention(query, key, value, **attn_kwargs)
         return self.scale * rearrange(attn_out, "b s h d -> b s (h d)")
 
     @classmethod
diffsynth_engine/models/qwen_image/qwen_image_dit.py

@@ -167,7 +167,6 @@ class QwenDoubleStreamAttention(nn.Module):
         dim_b,
         num_heads,
         head_dim,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -189,7 +188,6 @@ class QwenDoubleStreamAttention(nn.Module):
 
         self.to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
         self.to_add_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
 
     def forward(
         self,
@@ -197,6 +195,7 @@ class QwenDoubleStreamAttention(nn.Module):
         text: torch.FloatTensor,
         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
         img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
         txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
@@ -227,7 +226,8 @@ class QwenDoubleStreamAttention(nn.Module):
         joint_k = joint_k.transpose(1, 2)
         joint_v = joint_v.transpose(1, 2)
 
-        joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **self.attn_kwargs)
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs)
 
         joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype)
 
@@ -247,7 +247,6 @@ class QwenImageTransformerBlock(nn.Module):
         num_attention_heads: int,
         attention_head_dim: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -267,7 +266,6 @@ class QwenImageTransformerBlock(nn.Module):
             dim_b=dim,
             num_heads=num_attention_heads,
             head_dim=attention_head_dim,
-            attn_kwargs=attn_kwargs,
             device=device,
             dtype=dtype,
         )
@@ -293,6 +291,7 @@ class QwenImageTransformerBlock(nn.Module):
         temb: torch.Tensor,
         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
         txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
@@ -308,6 +307,7 @@ class QwenImageTransformerBlock(nn.Module):
             text=txt_modulated,
             rotary_emb=rotary_emb,
             attn_mask=attn_mask,
+            attn_kwargs=attn_kwargs,
         )
 
         image = image + img_gate * img_attn_out
@@ -335,7 +335,6 @@ class QwenImageDiT(PreTrainedModel):
     def __init__(
         self,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -356,7 +355,6 @@ class QwenImageDiT(PreTrainedModel):
                     dim=3072,
                     num_attention_heads=24,
                     attention_head_dim=128,
-                    attn_kwargs=attn_kwargs,
                     device=device,
                     dtype=dtype,
                 )
@@ -444,6 +442,7 @@ class QwenImageDiT(PreTrainedModel):
         entity_text: Optional[List[torch.Tensor]] = None,
         entity_seq_lens: Optional[List[torch.LongTensor]] = None,
         entity_masks: Optional[List[torch.Tensor]] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         h, w = image.shape[-2:]
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
@@ -509,7 +508,12 @@ class QwenImageDiT(PreTrainedModel):
         rotary_emb = (img_freqs, txt_freqs)
         for block in self.transformer_blocks:
             text, image = block(
-                image=image, text=text, temb=conditioning, rotary_emb=rotary_emb, attn_mask=attn_mask
+                image=image,
+                text=text,
+                temb=conditioning,
+                rotary_emb=rotary_emb,
+                attn_mask=attn_mask,
+                attn_kwargs=attn_kwargs,
             )
         image = self.norm_out(image, conditioning)
         image = self.proj_out(image)
@@ -527,14 +531,8 @@ class QwenImageDiT(PreTrainedModel):
         device: str,
         dtype: torch.dtype,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        model = cls(
-            device="meta",
-            dtype=dtype,
-            num_layers=num_layers,
-            attn_kwargs=attn_kwargs,
-        )
+        model = cls(device="meta", dtype=dtype, num_layers=num_layers)
         model = model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py

@@ -11,12 +11,11 @@ class QwenImageDiTFBCache(QwenImageDiT):
     def __init__(
         self,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
         relative_l1_threshold: float = 0.05,
     ):
-        super().__init__(num_layers=num_layers, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+        super().__init__(num_layers=num_layers, device=device, dtype=dtype)
         self.relative_l1_threshold = relative_l1_threshold
         self.step_count = 0
         self.num_inference_steps = 0
@@ -43,6 +42,7 @@ class QwenImageDiTFBCache(QwenImageDiT):
         text: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         txt_seq_lens: torch.LongTensor = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         h, w = image.shape[-2:]
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
@@ -72,7 +72,11 @@ class QwenImageDiTFBCache(QwenImageDiT):
         # first block
         original_hidden_states = image
         text, image = self.transformer_blocks[0](
-            image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb
+            image=image,
+            text=text,
+            temb=conditioning,
+            image_rotary_emb=image_rotary_emb,
+            attn_kwargs=attn_kwargs,
         )
         first_hidden_states_residual = image - original_hidden_states
 
@@ -94,7 +98,13 @@ class QwenImageDiTFBCache(QwenImageDiT):
            first_hidden_states = image.clone()
 
            for block in self.transformer_blocks[1:]:
-                text, image = block(image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb)
+                text, image = block(
+                    image=image,
+                    text=text,
+                    temb=conditioning,
+                    image_rotary_emb=image_rotary_emb,
+                    attn_kwargs=attn_kwargs,
+                )
 
            previous_residual = image - first_hidden_states
            self.previous_residual = previous_residual
@@ -114,14 +124,12 @@ class QwenImageDiTFBCache(QwenImageDiT):
         device: str,
         dtype: torch.dtype,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         relative_l1_threshold: float = 0.05,
     ):
         model = cls(
             device="meta",
             dtype=dtype,
             num_layers=num_layers,
-            attn_kwargs=attn_kwargs,
             relative_l1_threshold=relative_l1_threshold,
         )
         model = model.requires_grad_(False)