diffsynth-engine 0.6.1.dev22__py3-none-any.whl → 0.6.1.dev24__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (38)
  1. diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  2. diffsynth_engine/configs/pipeline.py +35 -12
  3. diffsynth_engine/models/basic/attention.py +59 -20
  4. diffsynth_engine/models/basic/transformer_helper.py +36 -2
  5. diffsynth_engine/models/basic/video_sparse_attention.py +235 -0
  6. diffsynth_engine/models/flux/flux_controlnet.py +7 -19
  7. diffsynth_engine/models/flux/flux_dit.py +22 -36
  8. diffsynth_engine/models/flux/flux_dit_fbcache.py +9 -7
  9. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  10. diffsynth_engine/models/qwen_image/qwen_image_dit.py +26 -32
  11. diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  12. diffsynth_engine/models/wan/wan_dit.py +62 -22
  13. diffsynth_engine/pipelines/flux_image.py +11 -10
  14. diffsynth_engine/pipelines/qwen_image.py +16 -15
  15. diffsynth_engine/pipelines/utils.py +52 -0
  16. diffsynth_engine/pipelines/wan_s2v.py +3 -8
  17. diffsynth_engine/pipelines/wan_video.py +11 -13
  18. diffsynth_engine/tokenizers/base.py +6 -0
  19. diffsynth_engine/tokenizers/qwen2.py +12 -4
  20. diffsynth_engine/utils/constants.py +13 -12
  21. diffsynth_engine/utils/flag.py +6 -0
  22. diffsynth_engine/utils/parallel.py +51 -6
  23. {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev24.dist-info}/METADATA +1 -1
  24. {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev24.dist-info}/RECORD +38 -36
  25. /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
  26. /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
  27. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
  28. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
  29. /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
  30. /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
  31. /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
  32. /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
  33. /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
  34. /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
  35. /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
  36. {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev24.dist-info}/WHEEL +0 -0
  37. {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev24.dist-info}/licenses/LICENSE +0 -0
  38. {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev24.dist-info}/top_level.txt +0 -0
@@ -176,7 +176,6 @@ class FluxDoubleAttention(nn.Module):
         dim_b,
         num_heads,
         head_dim,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -194,19 +193,20 @@ class FluxDoubleAttention(nn.Module):

         self.a_to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
         self.b_to_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

     def attention_callback(self, attn_out_a, attn_out_b, x_a, x_b, q_a, q_b, k_a, k_b, v_a, v_b, rope_emb, image_emb):
         return attn_out_a, attn_out_b

-    def forward(self, image, text, rope_emb, image_emb):
+    def forward(self, image, text, rope_emb, image_emb, attn_kwargs=None):
         q_a, k_a, v_a = rearrange(self.a_to_qkv(image), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q_b, k_b, v_b = rearrange(self.b_to_qkv(text), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q = torch.cat([self.norm_q_b(q_b), self.norm_q_a(q_a)], dim=1)
         k = torch.cat([self.norm_k_b(k_b), self.norm_k_a(k_a)], dim=1)
         v = torch.cat([v_b, v_a], dim=1)
         q, k = apply_rope(q, k, rope_emb)
-        attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        attn_out = attention_ops.attention(q, k, v, **attn_kwargs)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         text_out, image_out = attn_out[:, : text.shape[1]], attn_out[:, text.shape[1] :]
         image_out, text_out = self.attention_callback(
@@ -231,14 +231,11 @@ class FluxDoubleTransformerBlock(nn.Module):
         self,
         dim,
         num_heads,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
         super().__init__()
-        self.attn = FluxDoubleAttention(
-            dim, dim, num_heads, dim // num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype
-        )
+        self.attn = FluxDoubleAttention(dim, dim, num_heads, dim // num_heads, device=device, dtype=dtype)
         # Image
         self.norm_msa_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
         self.norm_mlp_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
@@ -256,11 +253,11 @@ class FluxDoubleTransformerBlock(nn.Module):
             nn.Linear(dim * 4, dim, device=device, dtype=dtype),
         )

-    def forward(self, image, text, t_emb, rope_emb, image_emb=None):
+    def forward(self, image, text, t_emb, rope_emb, image_emb=None, attn_kwargs=None):
         # AdaLayerNorm-Zero for Image and Text MSA
         image_in, gate_a = self.norm_msa_a(image, t_emb)
         text_in, gate_b = self.norm_msa_b(text, t_emb)
-        image_out, text_out = self.attn(image_in, text_in, rope_emb, image_emb)
+        image_out, text_out = self.attn(image_in, text_in, rope_emb, image_emb, attn_kwargs)
         image = image + gate_a * image_out
         text = text + gate_b * text_out

@@ -279,7 +276,6 @@ class FluxSingleAttention(nn.Module):
         self,
         dim,
         num_heads,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -288,15 +284,16 @@ class FluxSingleAttention(nn.Module):
         self.to_qkv = nn.Linear(dim, dim * 3, device=device, dtype=dtype)
         self.norm_q_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
         self.norm_k_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

     def attention_callback(self, attn_out, x, q, k, v, rope_emb, image_emb):
         return attn_out

-    def forward(self, x, rope_emb, image_emb):
+    def forward(self, x, rope_emb, image_emb, attn_kwargs=None):
         q, k, v = rearrange(self.to_qkv(x), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q, k = apply_rope(self.norm_q_a(q), self.norm_k_a(k), rope_emb)
-        attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        attn_out = attention_ops.attention(q, k, v, **attn_kwargs)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         return self.attention_callback(attn_out=attn_out, x=x, q=q, k=k, v=v, rope_emb=rope_emb, image_emb=image_emb)

@@ -306,23 +303,22 @@ class FluxSingleTransformerBlock(nn.Module):
         self,
         dim,
         num_heads,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
         super().__init__()
         self.dim = dim
         self.norm = AdaLayerNormZero(dim, device=device, dtype=dtype)
-        self.attn = FluxSingleAttention(dim, num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+        self.attn = FluxSingleAttention(dim, num_heads, device=device, dtype=dtype)
         self.mlp = nn.Sequential(
             nn.Linear(dim, dim * 4, device=device, dtype=dtype),
             nn.GELU(approximate="tanh"),
         )
         self.proj_out = nn.Linear(dim * 5, dim, device=device, dtype=dtype)

-    def forward(self, x, t_emb, rope_emb, image_emb=None):
+    def forward(self, x, t_emb, rope_emb, image_emb=None, attn_kwargs=None):
         h, gate = self.norm(x, emb=t_emb)
-        attn_output = self.attn(h, rope_emb, image_emb)
+        attn_output = self.attn(h, rope_emb, image_emb, attn_kwargs)
         mlp_output = self.mlp(h)
         return x + gate * self.proj_out(torch.cat([attn_output, mlp_output], dim=2))

@@ -334,7 +330,6 @@ class FluxDiT(PreTrainedModel):
     def __init__(
         self,
         in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -352,16 +347,10 @@ class FluxDiT(PreTrainedModel):
         self.x_embedder = nn.Linear(in_channel, 3072, device=device, dtype=dtype)

         self.blocks = nn.ModuleList(
-            [
-                FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
-                for _ in range(19)
-            ]
+            [FluxDoubleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(19)]
         )
         self.single_blocks = nn.ModuleList(
-            [
-                FluxSingleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
-                for _ in range(38)
-            ]
+            [FluxSingleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(38)]
         )
         self.final_norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
         self.final_proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
@@ -403,6 +392,7 @@ class FluxDiT(PreTrainedModel):
         text_ids: torch.Tensor,
         guidance: torch.Tensor,
         image_emb: torch.Tensor | None = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_double_block_output: List[torch.Tensor] | None = None,
         controlnet_single_block_output: List[torch.Tensor] | None = None,
         **kwargs,
@@ -470,14 +460,16 @@ class FluxDiT(PreTrainedModel):
             rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)

             for i, block in enumerate(self.blocks):
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+                hidden_states, prompt_emb = block(
+                    hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+                )
                 if len(controlnet_double_block_output) > 0:
                     interval_control = len(self.blocks) / len(controlnet_double_block_output)
                     interval_control = int(np.ceil(interval_control))
                     hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
             hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
             for i, block in enumerate(self.single_blocks):
-                hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+                hidden_states = block(hidden_states, conditioning, rope_emb, image_emb, attn_kwargs)
                 if len(controlnet_single_block_output) > 0:
                     interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
                     interval_control = int(np.ceil(interval_control))
@@ -498,14 +490,8 @@ class FluxDiT(PreTrainedModel):
         device: str,
         dtype: torch.dtype,
         in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        model = cls(
-            device="meta",
-            dtype=dtype,
-            in_channel=in_channel,
-            attn_kwargs=attn_kwargs,
-        )
+        model = cls(device="meta", dtype=dtype, in_channel=in_channel)
         model = model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
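Note on the flux_dit.py hunks above: attn_kwargs is no longer captured in __init__ and stored on each attention module; it is now an optional argument of forward() that is threaded through the transformer blocks and expanded into the attention call. A minimal, self-contained sketch of that pattern follows (it uses torch.nn.functional.scaled_dot_product_attention as a stand-in for attention_ops.attention; the class, shapes, and kwargs are illustrative, not diffsynth-engine API):

# Illustrative sketch only; SelfAttention and its shapes are placeholders,
# not diffsynth-engine classes. It mirrors the refactor above: attn_kwargs
# is a per-call argument forwarded into the attention op instead of a value
# frozen into the module at __init__ time.
import torch
import torch.nn.functional as F
from torch import nn


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int):  # note: no attn_kwargs here
        super().__init__()
        self.num_heads = num_heads
        self.to_qkv = nn.Linear(dim, dim * 3)

    def forward(self, x: torch.Tensor, attn_kwargs: dict | None = None) -> torch.Tensor:
        b, s, _ = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # (b, s, h*d) -> (b, h, s, d) for scaled_dot_product_attention
        q, k, v = (t.view(b, s, self.num_heads, -1).transpose(1, 2) for t in (q, k, v))
        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
        out = F.scaled_dot_product_attention(q, k, v, **attn_kwargs)
        return out.transpose(1, 2).reshape(b, s, -1)


x = torch.randn(1, 16, 64)
attn = SelfAttention(dim=64, num_heads=4)
y_default = attn(x)                                  # no kwargs, same behavior as before
y_causal = attn(x, attn_kwargs={"is_causal": True})  # chosen per call, not per module

One practical consequence of this shape of API is that attention options can differ between calls without rebuilding or mutating the module.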
@@ -20,12 +20,11 @@ class FluxDiTFBCache(FluxDiT):
     def __init__(
         self,
         in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
         relative_l1_threshold: float = 0.05,
     ):
-        super().__init__(in_channel=in_channel, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+        super().__init__(in_channel=in_channel, device=device, dtype=dtype)
         self.relative_l1_threshold = relative_l1_threshold
         self.step_count = 0
         self.num_inference_steps = 0
@@ -56,6 +55,7 @@ class FluxDiTFBCache(FluxDiT):
         text_ids: torch.Tensor,
         guidance: torch.Tensor,
         image_emb: torch.Tensor | None = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_double_block_output: List[torch.Tensor] | None = None,
         controlnet_single_block_output: List[torch.Tensor] | None = None,
         **kwargs,
@@ -124,7 +124,9 @@ class FluxDiTFBCache(FluxDiT):

             # first block
             original_hidden_states = hidden_states
-            hidden_states, prompt_emb = self.blocks[0](hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+            hidden_states, prompt_emb = self.blocks[0](
+                hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+            )
             first_hidden_states_residual = hidden_states - original_hidden_states

             (first_hidden_states_residual,) = sequence_parallel_unshard(
@@ -149,14 +151,16 @@ class FluxDiTFBCache(FluxDiT):

             first_hidden_states = hidden_states.clone()
             for i, block in enumerate(self.blocks[1:]):
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+                hidden_states, prompt_emb = block(
+                    hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+                )
                 if len(controlnet_double_block_output) > 0:
                     interval_control = len(self.blocks) / len(controlnet_double_block_output)
                     interval_control = int(np.ceil(interval_control))
                     hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
             hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
             for i, block in enumerate(self.single_blocks):
-                hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+                hidden_states = block(hidden_states, conditioning, rope_emb, image_emb, attn_kwargs)
                 if len(controlnet_single_block_output) > 0:
                     interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
                     interval_control = int(np.ceil(interval_control))
@@ -182,14 +186,12 @@ class FluxDiTFBCache(FluxDiT):
         device: str,
         dtype: torch.dtype,
         in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         relative_l1_threshold: float = 0.05,
     ):
         model = cls(
             device="meta",
             dtype=dtype,
             in_channel=in_channel,
-            attn_kwargs=attn_kwargs,
             relative_l1_threshold=relative_l1_threshold,
         )
         model = model.requires_grad_(False)
@@ -2,7 +2,7 @@ import torch
 from einops import rearrange
 from torch import nn
 from PIL import Image
-from typing import Any, Dict, List, Optional
+from typing import Dict, List
 from functools import partial
 from diffsynth_engine.models.text_encoder.siglip import SiglipImageEncoder
 from diffsynth_engine.models.basic.transformer_helper import RMSNorm
@@ -18,7 +18,6 @@ class FluxIPAdapterAttention(nn.Module):
         dim: int = 3072,
         head_num: int = 24,
         scale: float = 1.0,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -28,12 +27,13 @@ class FluxIPAdapterAttention(nn.Module):
         self.to_v_ip = nn.Linear(image_emb_dim, dim, device=device, dtype=dtype, bias=False)
         self.head_num = head_num
         self.scale = scale
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

-    def forward(self, query: torch.Tensor, image_emb: torch.Tensor):
+    def forward(self, query: torch.Tensor, image_emb: torch.Tensor, attn_kwargs=None):
         key = rearrange(self.norm_k(self.to_k_ip(image_emb)), "b s (h d) -> b s h d", h=self.head_num)
         value = rearrange(self.to_v_ip(image_emb), "b s (h d) -> b s h d", h=self.head_num)
-        attn_out = attention(query, key, value, **self.attn_kwargs)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        attn_out = attention(query, key, value, **attn_kwargs)
         return self.scale * rearrange(attn_out, "b s h d -> b s (h d)")

     @classmethod
@@ -6,7 +6,7 @@ from einops import rearrange
 from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
 from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
-from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, ApproximateGELU, RMSNorm
+from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, GELU, RMSNorm
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.parallel import (
@@ -144,7 +144,7 @@ class QwenFeedForward(nn.Module):
         super().__init__()
         inner_dim = int(dim * 4)
         self.net = nn.ModuleList([])
-        self.net.append(ApproximateGELU(dim, inner_dim, device=device, dtype=dtype))
+        self.net.append(GELU(dim, inner_dim, approximate="tanh", device=device, dtype=dtype))
         self.net.append(nn.Dropout(dropout))
         self.net.append(nn.Linear(inner_dim, dim_out, device=device, dtype=dtype))

@@ -155,8 +155,8 @@ class QwenFeedForward(nn.Module):


 def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]):
-    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-    x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # (b, s, h, d) -> (b, s, h, d/2, 2)
+    x_out = torch.view_as_real(x_rotated * freqs_cis.unsqueeze(1)).flatten(3)  # (b, s, h, d/2, 2) -> (b, s, h, d)
     return x_out.type_as(x)

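The apply_rotary_emb_qwen change above keeps queries and keys in (b, s, h, d) layout, so the complex rotation now needs freqs_cis.unsqueeze(1) to broadcast over the head axis. Below is a small sketch of why the two layouts give the same result, assuming freqs_cis is a complex tensor of shape (seq_len, head_dim // 2); the library's actual shape may differ.

# Sketch of why the unsqueeze(1) shows up: with queries/keys kept in the
# (b, s, h, d) layout, the per-position complex rotations need an explicit
# head axis to broadcast over. Assumes freqs_cis has shape (s, d/2); the
# library's real shape may differ.
import torch

b, s, h, d = 2, 8, 4, 16
x = torch.randn(b, s, h, d)
angles = torch.outer(torch.arange(s, dtype=torch.float32), torch.rand(d // 2))
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex, shape (s, d/2)

def rotate_bshd(x, freqs_cis):
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # (b, s, h, d/2)
    return torch.view_as_real(x_c * freqs_cis.unsqueeze(1)).flatten(3).type_as(x)

def rotate_bhsd(x, freqs_cis):
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # (b, h, s, d/2)
    return torch.view_as_real(x_c * freqs_cis).flatten(3).type_as(x)

# Same rotation, two layouts: rotate in (b, s, h, d) directly, or transpose
# to the old (b, h, s, d) layout, rotate, and transpose back.
out_new = rotate_bshd(x, freqs_cis)
out_old = rotate_bhsd(x.transpose(1, 2), freqs_cis).transpose(1, 2)
assert torch.allclose(out_new, out_old, atol=1e-6)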
@@ -167,7 +167,6 @@ class QwenDoubleStreamAttention(nn.Module):
         dim_b,
         num_heads,
         head_dim,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -189,7 +188,6 @@ class QwenDoubleStreamAttention(nn.Module):

         self.to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
         self.to_add_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

     def forward(
         self,
@@ -197,17 +195,18 @@ class QwenDoubleStreamAttention(nn.Module):
         text: torch.FloatTensor,
         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
         img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
         txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)

-        img_q = rearrange(img_q, "b s (h d) -> b h s d", h=self.num_heads)
-        img_k = rearrange(img_k, "b s (h d) -> b h s d", h=self.num_heads)
-        img_v = rearrange(img_v, "b s (h d) -> b h s d", h=self.num_heads)
+        img_q = rearrange(img_q, "b s (h d) -> b s h d", h=self.num_heads)
+        img_k = rearrange(img_k, "b s (h d) -> b s h d", h=self.num_heads)
+        img_v = rearrange(img_v, "b s (h d) -> b s h d", h=self.num_heads)

-        txt_q = rearrange(txt_q, "b s (h d) -> b h s d", h=self.num_heads)
-        txt_k = rearrange(txt_k, "b s (h d) -> b h s d", h=self.num_heads)
-        txt_v = rearrange(txt_v, "b s (h d) -> b h s d", h=self.num_heads)
+        txt_q = rearrange(txt_q, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_k = rearrange(txt_k, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_v = rearrange(txt_v, "b s (h d) -> b s h d", h=self.num_heads)

         img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
         txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
@@ -219,15 +218,12 @@ class QwenDoubleStreamAttention(nn.Module):
         txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
         txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)

-        joint_q = torch.cat([txt_q, img_q], dim=2)
-        joint_k = torch.cat([txt_k, img_k], dim=2)
-        joint_v = torch.cat([txt_v, img_v], dim=2)
+        joint_q = torch.cat([txt_q, img_q], dim=1)
+        joint_k = torch.cat([txt_k, img_k], dim=1)
+        joint_v = torch.cat([txt_v, img_v], dim=1)

-        joint_q = joint_q.transpose(1, 2)
-        joint_k = joint_k.transpose(1, 2)
-        joint_v = joint_v.transpose(1, 2)
-
-        joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **self.attn_kwargs)
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs)

         joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype)

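In the QwenDoubleStreamAttention hunks above, text and image tokens now stay in (b, s, h, d) layout, so the joint sequence is built with torch.cat(..., dim=1) and the old transpose(1, 2) calls disappear. A rough sketch of that joint-attention step with hypothetical shapes, using plain scaled_dot_product_attention in place of attention_ops.attention (which is why explicit transposes reappear here):

# Sketch only; shapes and helper names are placeholders, not library API.
# Text and image tokens are concatenated along the sequence axis (dim=1 in
# the b s h d layout) and attended to jointly in one call.
import torch
import torch.nn.functional as F

b, h, d = 1, 24, 128
txt_len, img_len = 32, 1024

def fake_qkv(seq_len):
    return (torch.randn(b, seq_len, h, d) for _ in range(3))

txt_q, txt_k, txt_v = fake_qkv(txt_len)
img_q, img_k, img_v = fake_qkv(img_len)

# joint sequence: text tokens first, then image tokens (as in the diff)
joint_q = torch.cat([txt_q, img_q], dim=1)
joint_k = torch.cat([txt_k, img_k], dim=1)
joint_v = torch.cat([txt_v, img_v], dim=1)

out = F.scaled_dot_product_attention(
    joint_q.transpose(1, 2), joint_k.transpose(1, 2), joint_v.transpose(1, 2)
).transpose(1, 2)  # back to (b, s, h, d)

txt_out, img_out = out[:, :txt_len], out[:, txt_len:]  # split the streams back apart
print(txt_out.shape, img_out.shape)  # torch.Size([1, 32, 24, 128]) torch.Size([1, 1024, 24, 128])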
@@ -247,7 +243,6 @@ class QwenImageTransformerBlock(nn.Module):
         num_attention_heads: int,
         attention_head_dim: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -267,7 +262,6 @@ class QwenImageTransformerBlock(nn.Module):
             dim_b=dim,
             num_heads=num_attention_heads,
             head_dim=attention_head_dim,
-            attn_kwargs=attn_kwargs,
             device=device,
             dtype=dtype,
         )
@@ -293,6 +287,7 @@ class QwenImageTransformerBlock(nn.Module):
         temb: torch.Tensor,
         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
         txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
@@ -308,6 +303,7 @@ class QwenImageTransformerBlock(nn.Module):
             text=txt_modulated,
             rotary_emb=rotary_emb,
             attn_mask=attn_mask,
+            attn_kwargs=attn_kwargs,
         )

         image = image + img_gate * img_attn_out
@@ -335,7 +331,6 @@ class QwenImageDiT(PreTrainedModel):
     def __init__(
         self,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -356,7 +351,6 @@ class QwenImageDiT(PreTrainedModel):
                 dim=3072,
                 num_attention_heads=24,
                 attention_head_dim=128,
-                attn_kwargs=attn_kwargs,
                 device=device,
                 dtype=dtype,
             )
@@ -444,6 +438,7 @@ class QwenImageDiT(PreTrainedModel):
         entity_text: Optional[List[torch.Tensor]] = None,
         entity_seq_lens: Optional[List[torch.LongTensor]] = None,
         entity_masks: Optional[List[torch.Tensor]] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         h, w = image.shape[-2:]
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
@@ -509,7 +504,12 @@ class QwenImageDiT(PreTrainedModel):
             rotary_emb = (img_freqs, txt_freqs)
             for block in self.transformer_blocks:
                 text, image = block(
-                    image=image, text=text, temb=conditioning, rotary_emb=rotary_emb, attn_mask=attn_mask
+                    image=image,
+                    text=text,
+                    temb=conditioning,
+                    rotary_emb=rotary_emb,
+                    attn_mask=attn_mask,
+                    attn_kwargs=attn_kwargs,
                 )
             image = self.norm_out(image, conditioning)
             image = self.proj_out(image)
@@ -527,14 +527,8 @@ class QwenImageDiT(PreTrainedModel):
         device: str,
         dtype: torch.dtype,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        model = cls(
-            device="meta",
-            dtype=dtype,
-            num_layers=num_layers,
-            attn_kwargs=attn_kwargs,
-        )
+        model = cls(device="meta", dtype=dtype, num_layers=num_layers)
         model = model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
@@ -11,12 +11,11 @@ class QwenImageDiTFBCache(QwenImageDiT):
     def __init__(
         self,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
         relative_l1_threshold: float = 0.05,
     ):
-        super().__init__(num_layers=num_layers, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+        super().__init__(num_layers=num_layers, device=device, dtype=dtype)
         self.relative_l1_threshold = relative_l1_threshold
         self.step_count = 0
         self.num_inference_steps = 0
@@ -43,6 +42,7 @@ class QwenImageDiTFBCache(QwenImageDiT):
         text: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         txt_seq_lens: torch.LongTensor = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         h, w = image.shape[-2:]
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
@@ -72,7 +72,11 @@ class QwenImageDiTFBCache(QwenImageDiT):
             # first block
             original_hidden_states = image
             text, image = self.transformer_blocks[0](
-                image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb
+                image=image,
+                text=text,
+                temb=conditioning,
+                image_rotary_emb=image_rotary_emb,
+                attn_kwargs=attn_kwargs,
             )
             first_hidden_states_residual = image - original_hidden_states

@@ -94,7 +98,13 @@ class QwenImageDiTFBCache(QwenImageDiT):
             first_hidden_states = image.clone()

             for block in self.transformer_blocks[1:]:
-                text, image = block(image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb)
+                text, image = block(
+                    image=image,
+                    text=text,
+                    temb=conditioning,
+                    image_rotary_emb=image_rotary_emb,
+                    attn_kwargs=attn_kwargs,
+                )

             previous_residual = image - first_hidden_states
             self.previous_residual = previous_residual
@@ -114,14 +124,12 @@ class QwenImageDiTFBCache(QwenImageDiT):
         device: str,
         dtype: torch.dtype,
         num_layers: int = 60,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         relative_l1_threshold: float = 0.05,
     ):
         model = cls(
             device="meta",
             dtype=dtype,
             num_layers=num_layers,
-            attn_kwargs=attn_kwargs,
             relative_l1_threshold=relative_l1_threshold,
         )
         model = model.requires_grad_(False)