diffsynth-engine 0.6.1.dev14__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +6 -2
- diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
- diffsynth_engine/configs/__init__.py +10 -6
- diffsynth_engine/configs/pipeline.py +17 -10
- diffsynth_engine/models/base.py +1 -1
- diffsynth_engine/models/basic/attention.py +59 -20
- diffsynth_engine/models/basic/transformer_helper.py +36 -2
- diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
- diffsynth_engine/models/flux/flux_controlnet.py +7 -19
- diffsynth_engine/models/flux/flux_dit.py +27 -38
- diffsynth_engine/models/flux/flux_dit_fbcache.py +9 -7
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +28 -34
- diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
- diffsynth_engine/models/wan/wan_audio_encoder.py +0 -1
- diffsynth_engine/models/wan/wan_dit.py +64 -27
- diffsynth_engine/pipelines/base.py +36 -4
- diffsynth_engine/pipelines/flux_image.py +19 -17
- diffsynth_engine/pipelines/qwen_image.py +45 -36
- diffsynth_engine/pipelines/sdxl_image.py +1 -1
- diffsynth_engine/pipelines/utils.py +52 -0
- diffsynth_engine/pipelines/wan_s2v.py +4 -9
- diffsynth_engine/pipelines/wan_video.py +43 -19
- diffsynth_engine/tokenizers/base.py +6 -0
- diffsynth_engine/tokenizers/qwen2.py +12 -4
- diffsynth_engine/utils/constants.py +13 -12
- diffsynth_engine/utils/flag.py +6 -0
- diffsynth_engine/utils/parallel.py +62 -29
- {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +45 -43
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
- {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
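
The recurring change across the Flux and Qwen-Image attention modules in the hunks below is that attn_kwargs is no longer fixed at construction time; each forward call now accepts an optional attn_kwargs dict and threads it down to attention_ops.attention. A minimal sketch of the resulting call pattern at the transformer-block level; the tensor names are placeholders and the empty dict stands in for backend-specific options, which this diff does not enumerate:

    # Sketch only: whatever keys attn_kwargs carries are forwarded verbatim to
    # attention_ops.attention, so valid contents depend on the selected backend.
    attn_kwargs = {}
    image, text = double_block(image, text, t_emb, rope_emb, image_emb=None, attn_kwargs=attn_kwargs)
    x = single_block(x, t_emb, rope_emb, image_emb=None, attn_kwargs=attn_kwargs)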

diffsynth_engine/models/flux/flux_dit.py

@@ -176,7 +176,6 @@ class FluxDoubleAttention(nn.Module):
         dim_b,
         num_heads,
         head_dim,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -194,19 +193,20 @@ class FluxDoubleAttention(nn.Module):

         self.a_to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
         self.b_to_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

     def attention_callback(self, attn_out_a, attn_out_b, x_a, x_b, q_a, q_b, k_a, k_b, v_a, v_b, rope_emb, image_emb):
         return attn_out_a, attn_out_b

-    def forward(self, image, text, rope_emb, image_emb):
+    def forward(self, image, text, rope_emb, image_emb, attn_kwargs=None):
         q_a, k_a, v_a = rearrange(self.a_to_qkv(image), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q_b, k_b, v_b = rearrange(self.b_to_qkv(text), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q = torch.cat([self.norm_q_b(q_b), self.norm_q_a(q_a)], dim=1)
         k = torch.cat([self.norm_k_b(k_b), self.norm_k_a(k_a)], dim=1)
         v = torch.cat([v_b, v_a], dim=1)
         q, k = apply_rope(q, k, rope_emb)
-        attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        attn_out = attention_ops.attention(q, k, v, **attn_kwargs)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         text_out, image_out = attn_out[:, : text.shape[1]], attn_out[:, text.shape[1] :]
         image_out, text_out = self.attention_callback(
@@ -231,14 +231,11 @@ class FluxDoubleTransformerBlock(nn.Module):
         self,
         dim,
         num_heads,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
         super().__init__()
-        self.attn = FluxDoubleAttention(
-            dim, dim, num_heads, dim // num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype
-        )
+        self.attn = FluxDoubleAttention(dim, dim, num_heads, dim // num_heads, device=device, dtype=dtype)
         # Image
         self.norm_msa_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
         self.norm_mlp_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
@@ -256,11 +253,11 @@ class FluxDoubleTransformerBlock(nn.Module):
             nn.Linear(dim * 4, dim, device=device, dtype=dtype),
         )

-    def forward(self, image, text, t_emb, rope_emb, image_emb=None):
+    def forward(self, image, text, t_emb, rope_emb, image_emb=None, attn_kwargs=None):
         # AdaLayerNorm-Zero for Image and Text MSA
         image_in, gate_a = self.norm_msa_a(image, t_emb)
         text_in, gate_b = self.norm_msa_b(text, t_emb)
-        image_out, text_out = self.attn(image_in, text_in, rope_emb, image_emb)
+        image_out, text_out = self.attn(image_in, text_in, rope_emb, image_emb, attn_kwargs)
         image = image + gate_a * image_out
         text = text + gate_b * text_out

@@ -279,7 +276,6 @@ class FluxSingleAttention(nn.Module):
         self,
         dim,
         num_heads,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -288,15 +284,16 @@ class FluxSingleAttention(nn.Module):
         self.to_qkv = nn.Linear(dim, dim * 3, device=device, dtype=dtype)
         self.norm_q_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
         self.norm_k_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

     def attention_callback(self, attn_out, x, q, k, v, rope_emb, image_emb):
         return attn_out

-    def forward(self, x, rope_emb, image_emb):
+    def forward(self, x, rope_emb, image_emb, attn_kwargs=None):
         q, k, v = rearrange(self.to_qkv(x), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q, k = apply_rope(self.norm_q_a(q), self.norm_k_a(k), rope_emb)
-        attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        attn_out = attention_ops.attention(q, k, v, **attn_kwargs)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         return self.attention_callback(attn_out=attn_out, x=x, q=q, k=k, v=v, rope_emb=rope_emb, image_emb=image_emb)

@@ -306,23 +303,22 @@ class FluxSingleTransformerBlock(nn.Module):
         self,
         dim,
         num_heads,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
         super().__init__()
         self.dim = dim
         self.norm = AdaLayerNormZero(dim, device=device, dtype=dtype)
-        self.attn = FluxSingleAttention(dim, num_heads,
+        self.attn = FluxSingleAttention(dim, num_heads, device=device, dtype=dtype)
         self.mlp = nn.Sequential(
             nn.Linear(dim, dim * 4, device=device, dtype=dtype),
             nn.GELU(approximate="tanh"),
         )
         self.proj_out = nn.Linear(dim * 5, dim, device=device, dtype=dtype)

-    def forward(self, x, t_emb, rope_emb, image_emb=None):
+    def forward(self, x, t_emb, rope_emb, image_emb=None, attn_kwargs=None):
         h, gate = self.norm(x, emb=t_emb)
-        attn_output = self.attn(h, rope_emb, image_emb)
+        attn_output = self.attn(h, rope_emb, image_emb, attn_kwargs)
         mlp_output = self.mlp(h)
         return x + gate * self.proj_out(torch.cat([attn_output, mlp_output], dim=2))

@@ -334,7 +330,6 @@ class FluxDiT(PreTrainedModel):
     def __init__(
         self,
         in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -352,16 +347,10 @@ class FluxDiT(PreTrainedModel):
         self.x_embedder = nn.Linear(in_channel, 3072, device=device, dtype=dtype)

         self.blocks = nn.ModuleList(
-            [
-                FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
-                for _ in range(19)
-            ]
+            [FluxDoubleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(19)]
         )
         self.single_blocks = nn.ModuleList(
-            [
-                FluxSingleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
-                for _ in range(38)
-            ]
+            [FluxSingleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(38)]
         )
         self.final_norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
         self.final_proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
@@ -403,6 +392,7 @@ class FluxDiT(PreTrainedModel):
         text_ids: torch.Tensor,
         guidance: torch.Tensor,
         image_emb: torch.Tensor | None = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_double_block_output: List[torch.Tensor] | None = None,
         controlnet_single_block_output: List[torch.Tensor] | None = None,
         **kwargs,
@@ -470,14 +460,16 @@ class FluxDiT(PreTrainedModel):
         rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)

         for i, block in enumerate(self.blocks):
-            hidden_states, prompt_emb = block(
+            hidden_states, prompt_emb = block(
+                hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+            )
             if len(controlnet_double_block_output) > 0:
                 interval_control = len(self.blocks) / len(controlnet_double_block_output)
                 interval_control = int(np.ceil(interval_control))
                 hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
         hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
         for i, block in enumerate(self.single_blocks):
-            hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+            hidden_states = block(hidden_states, conditioning, rope_emb, image_emb, attn_kwargs)
             if len(controlnet_single_block_output) > 0:
                 interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
                 interval_control = int(np.ceil(interval_control))
@@ -498,14 +490,8 @@ class FluxDiT(PreTrainedModel):
         device: str,
         dtype: torch.dtype,
         in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        model = cls(
-            device="meta",
-            dtype=dtype,
-            in_channel=in_channel,
-            attn_kwargs=attn_kwargs,
-        )
+        model = cls(device="meta", dtype=dtype, in_channel=in_channel)
         model = model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
@@ -515,5 +501,8 @@ class FluxDiT(PreTrainedModel):
         for block in self.blocks:
             block.compile(*args, **kwargs)

-
-
+        for block in self.single_blocks:
+            block.compile(*args, **kwargs)
+
+    def get_fsdp_module_cls(self):
+        return {FluxDoubleTransformerBlock, FluxSingleTransformerBlock}
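
FluxDiT now exposes get_fsdp_module_cls, returning the block classes to treat as sharding units (QwenImageDiT and Qwen2_5_VLForConditionalGeneration gain the same hook further down). A sketch of how such a hook could feed PyTorch FSDP's transformer auto-wrap policy; this consumer-side wiring is an assumption for illustration, not diffsynth_engine's own parallel setup:

    # Assumed usage sketch; requires an initialized torch.distributed process group.
    import functools
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

    policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=dit.get_fsdp_module_cls(),  # {FluxDoubleTransformerBlock, FluxSingleTransformerBlock}
    )
    dit = FSDP(dit, auto_wrap_policy=policy)  # each listed block class becomes its own FSDP unit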

diffsynth_engine/models/flux/flux_dit_fbcache.py

@@ -20,12 +20,11 @@ class FluxDiTFBCache(FluxDiT):
     def __init__(
         self,
         in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
         relative_l1_threshold: float = 0.05,
     ):
-        super().__init__(in_channel=in_channel,
+        super().__init__(in_channel=in_channel, device=device, dtype=dtype)
         self.relative_l1_threshold = relative_l1_threshold
         self.step_count = 0
         self.num_inference_steps = 0
@@ -56,6 +55,7 @@ class FluxDiTFBCache(FluxDiT):
         text_ids: torch.Tensor,
         guidance: torch.Tensor,
         image_emb: torch.Tensor | None = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_double_block_output: List[torch.Tensor] | None = None,
         controlnet_single_block_output: List[torch.Tensor] | None = None,
         **kwargs,
@@ -124,7 +124,9 @@ class FluxDiTFBCache(FluxDiT):

         # first block
         original_hidden_states = hidden_states
-        hidden_states, prompt_emb = self.blocks[0](
+        hidden_states, prompt_emb = self.blocks[0](
+            hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+        )
         first_hidden_states_residual = hidden_states - original_hidden_states

         (first_hidden_states_residual,) = sequence_parallel_unshard(
@@ -149,14 +151,16 @@ class FluxDiTFBCache(FluxDiT):

         first_hidden_states = hidden_states.clone()
         for i, block in enumerate(self.blocks[1:]):
-            hidden_states, prompt_emb = block(
+            hidden_states, prompt_emb = block(
+                hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+            )
             if len(controlnet_double_block_output) > 0:
                 interval_control = len(self.blocks) / len(controlnet_double_block_output)
                 interval_control = int(np.ceil(interval_control))
                 hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
         hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
         for i, block in enumerate(self.single_blocks):
-            hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+            hidden_states = block(hidden_states, conditioning, rope_emb, image_emb, attn_kwargs)
             if len(controlnet_single_block_output) > 0:
                 interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
                 interval_control = int(np.ceil(interval_control))
@@ -182,14 +186,12 @@ class FluxDiTFBCache(FluxDiT):
         device: str,
         dtype: torch.dtype,
         in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         relative_l1_threshold: float = 0.05,
     ):
         model = cls(
             device="meta",
             dtype=dtype,
             in_channel=in_channel,
-            attn_kwargs=attn_kwargs,
             relative_l1_threshold=relative_l1_threshold,
         )
         model = model.requires_grad_(False)

diffsynth_engine/models/flux/flux_ipadapter.py

@@ -2,7 +2,7 @@ import torch
 from einops import rearrange
 from torch import nn
 from PIL import Image
-from typing import
+from typing import Dict, List
 from functools import partial
 from diffsynth_engine.models.text_encoder.siglip import SiglipImageEncoder
 from diffsynth_engine.models.basic.transformer_helper import RMSNorm
@@ -18,7 +18,6 @@ class FluxIPAdapterAttention(nn.Module):
         dim: int = 3072,
         head_num: int = 24,
         scale: float = 1.0,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -28,12 +27,13 @@ class FluxIPAdapterAttention(nn.Module):
         self.to_v_ip = nn.Linear(image_emb_dim, dim, device=device, dtype=dtype, bias=False)
         self.head_num = head_num
         self.scale = scale
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

-    def forward(self, query: torch.Tensor, image_emb: torch.Tensor):
+    def forward(self, query: torch.Tensor, image_emb: torch.Tensor, attn_kwargs=None):
         key = rearrange(self.norm_k(self.to_k_ip(image_emb)), "b s (h d) -> b s h d", h=self.head_num)
         value = rearrange(self.to_v_ip(image_emb), "b s (h d) -> b s h d", h=self.head_num)
-        attn_out = attention(query, key, value, **self.attn_kwargs)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        attn_out = attention(query, key, value, **attn_kwargs)
         return self.scale * rearrange(attn_out, "b s h d -> b s (h d)")

     @classmethod

diffsynth_engine/models/qwen_image/qwen2_5_vl.py

@@ -942,6 +942,8 @@ class Qwen2_5_VLModel(nn.Module):


 class Qwen2_5_VLForConditionalGeneration(PreTrainedModel):
+    _supports_parallelization = True
+
     def __init__(
         self,
         vision_config: Qwen2_5_VLVisionConfig,
@@ -1173,6 +1175,9 @@ class Qwen2_5_VLForConditionalGeneration(PreTrainedModel):

         return position_ids, mrope_position_deltas

+    def get_fsdp_module_cls(self):
+        return {Qwen2_5_VisionBlock, Qwen2_5_VLDecoderLayer}
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

diffsynth_engine/models/qwen_image/qwen_image_dit.py

@@ -6,7 +6,7 @@ from einops import rearrange
 from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
 from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
-from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm,
+from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, GELU, RMSNorm
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.parallel import (
@@ -144,7 +144,7 @@ class QwenFeedForward(nn.Module):
         super().__init__()
         inner_dim = int(dim * 4)
         self.net = nn.ModuleList([])
-        self.net.append(
+        self.net.append(GELU(dim, inner_dim, approximate="tanh", device=device, dtype=dtype))
         self.net.append(nn.Dropout(dropout))
         self.net.append(nn.Linear(inner_dim, dim_out, device=device, dtype=dtype))

@@ -155,8 +155,8 @@ class QwenFeedForward(nn.Module):


 def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]):
-    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-    x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # (b, s, h, d) -> (b, s, h, d/2, 2)
+    x_out = torch.view_as_real(x_rotated * freqs_cis.unsqueeze(1)).flatten(3)  # (b, s, h, d/2, 2) -> (b, s, h, d)
     return x_out.type_as(x)

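
The unsqueeze(1) added to apply_rotary_emb_qwen above broadcasts the complex rotary factors over the head axis now that q/k/v are laid out as (b, s, h, d) (see the QwenDoubleStreamAttention hunks below). A standalone shape check of that broadcast, assuming freqs_cis is a complex tensor of shape (seq_len, head_dim // 2); the concrete sizes are arbitrary:

    import torch

    b, s, h, d = 2, 8, 4, 16
    x = torch.randn(b, s, h, d)
    freqs_cis = torch.polar(torch.ones(s, d // 2), torch.rand(s, d // 2))  # complex, (s, d/2)

    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # (b, s, h, d/2) complex
    x_out = torch.view_as_real(x_rotated * freqs_cis.unsqueeze(1)).flatten(3)  # back to (b, s, h, d)
    assert x_out.shape == (b, s, h, d)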

@@ -167,7 +167,6 @@ class QwenDoubleStreamAttention(nn.Module):
         dim_b,
         num_heads,
         head_dim,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -189,7 +188,6 @@ class QwenDoubleStreamAttention(nn.Module):

         self.to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
         self.to_add_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

     def forward(
         self,
@@ -197,17 +195,18 @@ class QwenDoubleStreamAttention(nn.Module):
         text: torch.FloatTensor,
         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
         img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
         txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)

-        img_q = rearrange(img_q, "b s (h d) -> b h
-        img_k = rearrange(img_k, "b s (h d) -> b h
-        img_v = rearrange(img_v, "b s (h d) -> b h
+        img_q = rearrange(img_q, "b s (h d) -> b s h d", h=self.num_heads)
+        img_k = rearrange(img_k, "b s (h d) -> b s h d", h=self.num_heads)
+        img_v = rearrange(img_v, "b s (h d) -> b s h d", h=self.num_heads)

-        txt_q = rearrange(txt_q, "b s (h d) -> b h
-        txt_k = rearrange(txt_k, "b s (h d) -> b h
-        txt_v = rearrange(txt_v, "b s (h d) -> b h
+        txt_q = rearrange(txt_q, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_k = rearrange(txt_k, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_v = rearrange(txt_v, "b s (h d) -> b s h d", h=self.num_heads)

         img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
         txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
@@ -219,15 +218,12 @@ class QwenDoubleStreamAttention(nn.Module):
         txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
         txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)

-        joint_q = torch.cat([txt_q, img_q], dim=
-        joint_k = torch.cat([txt_k, img_k], dim=
-        joint_v = torch.cat([txt_v, img_v], dim=
+        joint_q = torch.cat([txt_q, img_q], dim=1)
+        joint_k = torch.cat([txt_k, img_k], dim=1)
+        joint_v = torch.cat([txt_v, img_v], dim=1)

-        joint_q = joint_q.transpose(1, 2)
-        joint_k = joint_k.transpose(1, 2)
-        joint_v = joint_v.transpose(1, 2)
-
-        joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **self.attn_kwargs)
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs)

         joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype)

@@ -247,7 +243,6 @@ class QwenImageTransformerBlock(nn.Module):
         num_attention_heads: int,
         attention_head_dim: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -267,7 +262,6 @@ class QwenImageTransformerBlock(nn.Module):
             dim_b=dim,
             num_heads=num_attention_heads,
             head_dim=attention_head_dim,
-            attn_kwargs=attn_kwargs,
             device=device,
             dtype=dtype,
         )
@@ -293,6 +287,7 @@ class QwenImageTransformerBlock(nn.Module):
         temb: torch.Tensor,
         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
         txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
@@ -308,6 +303,7 @@ class QwenImageTransformerBlock(nn.Module):
             text=txt_modulated,
             rotary_emb=rotary_emb,
             attn_mask=attn_mask,
+            attn_kwargs=attn_kwargs,
         )

         image = image + img_gate * img_attn_out
@@ -335,7 +331,6 @@ class QwenImageDiT(PreTrainedModel):
     def __init__(
         self,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -356,7 +351,6 @@ class QwenImageDiT(PreTrainedModel):
                 dim=3072,
                 num_attention_heads=24,
                 attention_head_dim=128,
-                attn_kwargs=attn_kwargs,
                 device=device,
                 dtype=dtype,
             )
@@ -444,6 +438,7 @@ class QwenImageDiT(PreTrainedModel):
         entity_text: Optional[List[torch.Tensor]] = None,
         entity_seq_lens: Optional[List[torch.LongTensor]] = None,
         entity_masks: Optional[List[torch.Tensor]] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         h, w = image.shape[-2:]
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
@@ -509,7 +504,12 @@ class QwenImageDiT(PreTrainedModel):
         rotary_emb = (img_freqs, txt_freqs)
         for block in self.transformer_blocks:
             text, image = block(
-                image=image,
+                image=image,
+                text=text,
+                temb=conditioning,
+                rotary_emb=rotary_emb,
+                attn_mask=attn_mask,
+                attn_kwargs=attn_kwargs,
             )
         image = self.norm_out(image, conditioning)
         image = self.proj_out(image)
@@ -527,14 +527,8 @@ class QwenImageDiT(PreTrainedModel):
         device: str,
         dtype: torch.dtype,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        model = cls(
-            device="meta",
-            dtype=dtype,
-            num_layers=num_layers,
-            attn_kwargs=attn_kwargs,
-        )
+        model = cls(device="meta", dtype=dtype, num_layers=num_layers)
         model = model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
@@ -544,5 +538,5 @@ class QwenImageDiT(PreTrainedModel):
         for block in self.transformer_blocks:
             block.compile(*args, **kwargs)

-    def
-        return
+    def get_fsdp_module_cls(self):
+        return {QwenImageTransformerBlock}

diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py

@@ -11,12 +11,11 @@ class QwenImageDiTFBCache(QwenImageDiT):
     def __init__(
         self,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
         relative_l1_threshold: float = 0.05,
     ):
-        super().__init__(num_layers=num_layers,
+        super().__init__(num_layers=num_layers, device=device, dtype=dtype)
         self.relative_l1_threshold = relative_l1_threshold
         self.step_count = 0
         self.num_inference_steps = 0
@@ -43,6 +42,7 @@ class QwenImageDiTFBCache(QwenImageDiT):
         text: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         txt_seq_lens: torch.LongTensor = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         h, w = image.shape[-2:]
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
@@ -72,7 +72,11 @@ class QwenImageDiTFBCache(QwenImageDiT):
         # first block
         original_hidden_states = image
         text, image = self.transformer_blocks[0](
-            image=image,
+            image=image,
+            text=text,
+            temb=conditioning,
+            image_rotary_emb=image_rotary_emb,
+            attn_kwargs=attn_kwargs,
         )
         first_hidden_states_residual = image - original_hidden_states

@@ -94,7 +98,13 @@ class QwenImageDiTFBCache(QwenImageDiT):
         first_hidden_states = image.clone()

         for block in self.transformer_blocks[1:]:
-            text, image = block(
+            text, image = block(
+                image=image,
+                text=text,
+                temb=conditioning,
+                image_rotary_emb=image_rotary_emb,
+                attn_kwargs=attn_kwargs,
+            )

         previous_residual = image - first_hidden_states
         self.previous_residual = previous_residual
@@ -114,14 +124,12 @@ class QwenImageDiTFBCache(QwenImageDiT):
         device: str,
         dtype: torch.dtype,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         relative_l1_threshold: float = 0.05,
     ):
         model = cls(
             device="meta",
             dtype=dtype,
             num_layers=num_layers,
-            attn_kwargs=attn_kwargs,
             relative_l1_threshold=relative_l1_threshold,
         )
         model = model.requires_grad_(False)

diffsynth_engine/models/wan/wan_audio_encoder.py

@@ -223,7 +223,6 @@ class Wav2Vec2StateDictConverter:

 class Wav2Vec2Model(PreTrainedModel):
     converter = Wav2Vec2StateDictConverter()
-    _supports_parallelization = False

     def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
         super().__init__()