diffsynth-engine 0.6.1.dev21__py3-none-any.whl → 0.6.1.dev23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  2. diffsynth_engine/configs/pipeline.py +35 -5
  3. diffsynth_engine/models/basic/attention.py +59 -20
  4. diffsynth_engine/models/basic/video_sparse_attention.py +235 -0
  5. diffsynth_engine/models/flux/flux_controlnet.py +7 -19
  6. diffsynth_engine/models/flux/flux_dit.py +22 -36
  7. diffsynth_engine/models/flux/flux_dit_fbcache.py +9 -7
  8. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  9. diffsynth_engine/models/qwen_image/qwen_image_dit.py +13 -15
  10. diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  11. diffsynth_engine/models/wan/wan_dit.py +62 -22
  12. diffsynth_engine/pipelines/flux_image.py +11 -10
  13. diffsynth_engine/pipelines/qwen_image.py +26 -28
  14. diffsynth_engine/pipelines/wan_s2v.py +3 -8
  15. diffsynth_engine/pipelines/wan_video.py +11 -13
  16. diffsynth_engine/utils/constants.py +13 -12
  17. diffsynth_engine/utils/flag.py +6 -0
  18. diffsynth_engine/utils/parallel.py +51 -6
  19. {diffsynth_engine-0.6.1.dev21.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/METADATA +1 -1
  20. {diffsynth_engine-0.6.1.dev21.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/RECORD +34 -32
  21. /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
  22. /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
  23. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
  24. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
  25. /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
  26. /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
  27. /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
  28. /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
  29. /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
  30. /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
  31. /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
  32. {diffsynth_engine-0.6.1.dev21.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/WHEEL +0 -0
  33. {diffsynth_engine-0.6.1.dev21.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/licenses/LICENSE +0 -0
  34. {diffsynth_engine-0.6.1.dev21.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/top_level.txt +0 -0
@@ -176,7 +176,6 @@ class FluxDoubleAttention(nn.Module):
  dim_b,
  num_heads,
  head_dim,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  ):
@@ -194,19 +193,20 @@ class FluxDoubleAttention(nn.Module):

  self.a_to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
  self.b_to_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
- self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

  def attention_callback(self, attn_out_a, attn_out_b, x_a, x_b, q_a, q_b, k_a, k_b, v_a, v_b, rope_emb, image_emb):
  return attn_out_a, attn_out_b

- def forward(self, image, text, rope_emb, image_emb):
+ def forward(self, image, text, rope_emb, image_emb, attn_kwargs=None):
  q_a, k_a, v_a = rearrange(self.a_to_qkv(image), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
  q_b, k_b, v_b = rearrange(self.b_to_qkv(text), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
  q = torch.cat([self.norm_q_b(q_b), self.norm_q_a(q_a)], dim=1)
  k = torch.cat([self.norm_k_b(k_b), self.norm_k_a(k_a)], dim=1)
  v = torch.cat([v_b, v_a], dim=1)
  q, k = apply_rope(q, k, rope_emb)
- attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
+
+ attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+ attn_out = attention_ops.attention(q, k, v, **attn_kwargs)
  attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
  text_out, image_out = attn_out[:, : text.shape[1]], attn_out[:, text.shape[1] :]
  image_out, text_out = self.attention_callback(
@@ -231,14 +231,11 @@ class FluxDoubleTransformerBlock(nn.Module):
  self,
  dim,
  num_heads,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  ):
  super().__init__()
- self.attn = FluxDoubleAttention(
- dim, dim, num_heads, dim // num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype
- )
+ self.attn = FluxDoubleAttention(dim, dim, num_heads, dim // num_heads, device=device, dtype=dtype)
  # Image
  self.norm_msa_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
  self.norm_mlp_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
@@ -256,11 +253,11 @@ class FluxDoubleTransformerBlock(nn.Module):
  nn.Linear(dim * 4, dim, device=device, dtype=dtype),
  )

- def forward(self, image, text, t_emb, rope_emb, image_emb=None):
+ def forward(self, image, text, t_emb, rope_emb, image_emb=None, attn_kwargs=None):
  # AdaLayerNorm-Zero for Image and Text MSA
  image_in, gate_a = self.norm_msa_a(image, t_emb)
  text_in, gate_b = self.norm_msa_b(text, t_emb)
- image_out, text_out = self.attn(image_in, text_in, rope_emb, image_emb)
+ image_out, text_out = self.attn(image_in, text_in, rope_emb, image_emb, attn_kwargs)
  image = image + gate_a * image_out
  text = text + gate_b * text_out

@@ -279,7 +276,6 @@ class FluxSingleAttention(nn.Module):
  self,
  dim,
  num_heads,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  ):
@@ -288,15 +284,16 @@ class FluxSingleAttention(nn.Module):
  self.to_qkv = nn.Linear(dim, dim * 3, device=device, dtype=dtype)
  self.norm_q_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
  self.norm_k_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
- self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

  def attention_callback(self, attn_out, x, q, k, v, rope_emb, image_emb):
  return attn_out

- def forward(self, x, rope_emb, image_emb):
+ def forward(self, x, rope_emb, image_emb, attn_kwargs=None):
  q, k, v = rearrange(self.to_qkv(x), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
  q, k = apply_rope(self.norm_q_a(q), self.norm_k_a(k), rope_emb)
- attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
+
+ attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+ attn_out = attention_ops.attention(q, k, v, **attn_kwargs)
  attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
  return self.attention_callback(attn_out=attn_out, x=x, q=q, k=k, v=v, rope_emb=rope_emb, image_emb=image_emb)

@@ -306,23 +303,22 @@ class FluxSingleTransformerBlock(nn.Module):
  self,
  dim,
  num_heads,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  ):
  super().__init__()
  self.dim = dim
  self.norm = AdaLayerNormZero(dim, device=device, dtype=dtype)
- self.attn = FluxSingleAttention(dim, num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+ self.attn = FluxSingleAttention(dim, num_heads, device=device, dtype=dtype)
  self.mlp = nn.Sequential(
  nn.Linear(dim, dim * 4, device=device, dtype=dtype),
  nn.GELU(approximate="tanh"),
  )
  self.proj_out = nn.Linear(dim * 5, dim, device=device, dtype=dtype)

- def forward(self, x, t_emb, rope_emb, image_emb=None):
+ def forward(self, x, t_emb, rope_emb, image_emb=None, attn_kwargs=None):
  h, gate = self.norm(x, emb=t_emb)
- attn_output = self.attn(h, rope_emb, image_emb)
+ attn_output = self.attn(h, rope_emb, image_emb, attn_kwargs)
  mlp_output = self.mlp(h)
  return x + gate * self.proj_out(torch.cat([attn_output, mlp_output], dim=2))

@@ -334,7 +330,6 @@ class FluxDiT(PreTrainedModel):
  def __init__(
  self,
  in_channel: int = 64,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  ):
@@ -352,16 +347,10 @@ class FluxDiT(PreTrainedModel):
  self.x_embedder = nn.Linear(in_channel, 3072, device=device, dtype=dtype)

  self.blocks = nn.ModuleList(
- [
- FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
- for _ in range(19)
- ]
+ [FluxDoubleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(19)]
  )
  self.single_blocks = nn.ModuleList(
- [
- FluxSingleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
- for _ in range(38)
- ]
+ [FluxSingleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(38)]
  )
  self.final_norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
  self.final_proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
@@ -403,6 +392,7 @@ class FluxDiT(PreTrainedModel):
  text_ids: torch.Tensor,
  guidance: torch.Tensor,
  image_emb: torch.Tensor | None = None,
+ attn_kwargs: Optional[Dict[str, Any]] = None,
  controlnet_double_block_output: List[torch.Tensor] | None = None,
  controlnet_single_block_output: List[torch.Tensor] | None = None,
  **kwargs,
@@ -470,14 +460,16 @@ class FluxDiT(PreTrainedModel):
  rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)

  for i, block in enumerate(self.blocks):
- hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+ hidden_states, prompt_emb = block(
+ hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+ )
  if len(controlnet_double_block_output) > 0:
  interval_control = len(self.blocks) / len(controlnet_double_block_output)
  interval_control = int(np.ceil(interval_control))
  hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
  hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
  for i, block in enumerate(self.single_blocks):
- hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+ hidden_states = block(hidden_states, conditioning, rope_emb, image_emb, attn_kwargs)
  if len(controlnet_single_block_output) > 0:
  interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
  interval_control = int(np.ceil(interval_control))
@@ -498,14 +490,8 @@ class FluxDiT(PreTrainedModel):
  device: str,
  dtype: torch.dtype,
  in_channel: int = 64,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  ):
- model = cls(
- device="meta",
- dtype=dtype,
- in_channel=in_channel,
- attn_kwargs=attn_kwargs,
- )
+ model = cls(device="meta", dtype=dtype, in_channel=in_channel)
  model = model.requires_grad_(False)
  model.load_state_dict(state_dict, assign=True)
  model.to(device=device, dtype=dtype, non_blocking=True)
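Note: the hunks above and below share one refactor. `attn_kwargs` is no longer captured in each module's `__init__` and stored as `self.attn_kwargs`; it is now threaded through `forward()` and resolved per call, so one module instance can serve different attention settings across calls. A minimal toy sketch of that pattern, with plain PyTorch SDPA standing in for `attention_ops.attention` (the class and kwargs here are illustrative only, not diffsynth_engine's API):

import torch
from torch import nn
from typing import Any, Dict, Optional

class ToyAttention(nn.Module):
    # Illustration of the refactor only: kwargs are resolved per forward() call,
    # mirroring the "attn_kwargs if attn_kwargs is not None else {}" lines added in this diff.
    def __init__(self, dim: int):
        super().__init__()
        self.to_qkv = nn.Linear(dim, dim * 3)

    def forward(self, x: torch.Tensor, attn_kwargs: Optional[Dict[str, Any]] = None) -> torch.Tensor:
        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        return nn.functional.scaled_dot_product_attention(q, k, v, **attn_kwargs)

attn = ToyAttention(dim=64)
x = torch.randn(2, 16, 64)
print(attn(x).shape)                                   # default attention settings
print(attn(x, attn_kwargs={"is_causal": True}).shape)  # per-call override, no re-instantiation

The practical effect, visible in the FluxDiT and QwenImageDiT forward signatures, is that attention settings become a per-inference argument the pipeline can pass down at call time rather than a property baked into every block at construction.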
@@ -20,12 +20,11 @@ class FluxDiTFBCache(FluxDiT):
  def __init__(
  self,
  in_channel: int = 64,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  relative_l1_threshold: float = 0.05,
  ):
- super().__init__(in_channel=in_channel, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+ super().__init__(in_channel=in_channel, device=device, dtype=dtype)
  self.relative_l1_threshold = relative_l1_threshold
  self.step_count = 0
  self.num_inference_steps = 0
@@ -56,6 +55,7 @@ class FluxDiTFBCache(FluxDiT):
  text_ids: torch.Tensor,
  guidance: torch.Tensor,
  image_emb: torch.Tensor | None = None,
+ attn_kwargs: Optional[Dict[str, Any]] = None,
  controlnet_double_block_output: List[torch.Tensor] | None = None,
  controlnet_single_block_output: List[torch.Tensor] | None = None,
  **kwargs,
@@ -124,7 +124,9 @@ class FluxDiTFBCache(FluxDiT):

  # first block
  original_hidden_states = hidden_states
- hidden_states, prompt_emb = self.blocks[0](hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+ hidden_states, prompt_emb = self.blocks[0](
+ hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+ )
  first_hidden_states_residual = hidden_states - original_hidden_states

  (first_hidden_states_residual,) = sequence_parallel_unshard(
@@ -149,14 +151,16 @@ class FluxDiTFBCache(FluxDiT):

  first_hidden_states = hidden_states.clone()
  for i, block in enumerate(self.blocks[1:]):
- hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+ hidden_states, prompt_emb = block(
+ hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+ )
  if len(controlnet_double_block_output) > 0:
  interval_control = len(self.blocks) / len(controlnet_double_block_output)
  interval_control = int(np.ceil(interval_control))
  hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
  hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
  for i, block in enumerate(self.single_blocks):
- hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+ hidden_states = block(hidden_states, conditioning, rope_emb, image_emb, attn_kwargs)
  if len(controlnet_single_block_output) > 0:
  interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
  interval_control = int(np.ceil(interval_control))
@@ -182,14 +186,12 @@ class FluxDiTFBCache(FluxDiT):
  device: str,
  dtype: torch.dtype,
  in_channel: int = 64,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  relative_l1_threshold: float = 0.05,
  ):
  model = cls(
  device="meta",
  dtype=dtype,
  in_channel=in_channel,
- attn_kwargs=attn_kwargs,
  relative_l1_threshold=relative_l1_threshold,
  )
  model = model.requires_grad_(False)
@@ -2,7 +2,7 @@ import torch
  from einops import rearrange
  from torch import nn
  from PIL import Image
- from typing import Any, Dict, List, Optional
+ from typing import Dict, List
  from functools import partial
  from diffsynth_engine.models.text_encoder.siglip import SiglipImageEncoder
  from diffsynth_engine.models.basic.transformer_helper import RMSNorm
@@ -18,7 +18,6 @@ class FluxIPAdapterAttention(nn.Module):
  dim: int = 3072,
  head_num: int = 24,
  scale: float = 1.0,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  ):
@@ -28,12 +27,13 @@ class FluxIPAdapterAttention(nn.Module):
  self.to_v_ip = nn.Linear(image_emb_dim, dim, device=device, dtype=dtype, bias=False)
  self.head_num = head_num
  self.scale = scale
- self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

- def forward(self, query: torch.Tensor, image_emb: torch.Tensor):
+ def forward(self, query: torch.Tensor, image_emb: torch.Tensor, attn_kwargs=None):
  key = rearrange(self.norm_k(self.to_k_ip(image_emb)), "b s (h d) -> b s h d", h=self.head_num)
  value = rearrange(self.to_v_ip(image_emb), "b s (h d) -> b s h d", h=self.head_num)
- attn_out = attention(query, key, value, **self.attn_kwargs)
+
+ attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+ attn_out = attention(query, key, value, **attn_kwargs)
  return self.scale * rearrange(attn_out, "b s h d -> b s (h d)")

  @classmethod
@@ -167,7 +167,6 @@ class QwenDoubleStreamAttention(nn.Module):
  dim_b,
  num_heads,
  head_dim,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  ):
@@ -189,7 +188,6 @@ class QwenDoubleStreamAttention(nn.Module):

  self.to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
  self.to_add_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
- self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

  def forward(
  self,
@@ -197,6 +195,7 @@ class QwenDoubleStreamAttention(nn.Module):
  text: torch.FloatTensor,
  rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  attn_mask: Optional[torch.Tensor] = None,
+ attn_kwargs: Optional[Dict[str, Any]] = None,
  ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
  img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
  txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
@@ -227,7 +226,8 @@ class QwenDoubleStreamAttention(nn.Module):
  joint_k = joint_k.transpose(1, 2)
  joint_v = joint_v.transpose(1, 2)

- joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **self.attn_kwargs)
+ attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+ joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs)

  joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype)

@@ -247,7 +247,6 @@ class QwenImageTransformerBlock(nn.Module):
  num_attention_heads: int,
  attention_head_dim: int,
  eps: float = 1e-6,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  ):
@@ -267,7 +266,6 @@ class QwenImageTransformerBlock(nn.Module):
  dim_b=dim,
  num_heads=num_attention_heads,
  head_dim=attention_head_dim,
- attn_kwargs=attn_kwargs,
  device=device,
  dtype=dtype,
  )
@@ -293,6 +291,7 @@ class QwenImageTransformerBlock(nn.Module):
  temb: torch.Tensor,
  rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  attn_mask: Optional[torch.Tensor] = None,
+ attn_kwargs: Optional[Dict[str, Any]] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
  img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
  txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
@@ -308,6 +307,7 @@ class QwenImageTransformerBlock(nn.Module):
  text=txt_modulated,
  rotary_emb=rotary_emb,
  attn_mask=attn_mask,
+ attn_kwargs=attn_kwargs,
  )

  image = image + img_gate * img_attn_out
@@ -335,7 +335,6 @@ class QwenImageDiT(PreTrainedModel):
  def __init__(
  self,
  num_layers: int = 60,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  ):
@@ -356,7 +355,6 @@ class QwenImageDiT(PreTrainedModel):
  dim=3072,
  num_attention_heads=24,
  attention_head_dim=128,
- attn_kwargs=attn_kwargs,
  device=device,
  dtype=dtype,
  )
@@ -444,6 +442,7 @@ class QwenImageDiT(PreTrainedModel):
  entity_text: Optional[List[torch.Tensor]] = None,
  entity_seq_lens: Optional[List[torch.LongTensor]] = None,
  entity_masks: Optional[List[torch.Tensor]] = None,
+ attn_kwargs: Optional[Dict[str, Any]] = None,
  ):
  h, w = image.shape[-2:]
  fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
@@ -509,7 +508,12 @@ class QwenImageDiT(PreTrainedModel):
  rotary_emb = (img_freqs, txt_freqs)
  for block in self.transformer_blocks:
  text, image = block(
- image=image, text=text, temb=conditioning, rotary_emb=rotary_emb, attn_mask=attn_mask
+ image=image,
+ text=text,
+ temb=conditioning,
+ rotary_emb=rotary_emb,
+ attn_mask=attn_mask,
+ attn_kwargs=attn_kwargs,
  )
  image = self.norm_out(image, conditioning)
  image = self.proj_out(image)
@@ -527,14 +531,8 @@ class QwenImageDiT(PreTrainedModel):
  device: str,
  dtype: torch.dtype,
  num_layers: int = 60,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  ):
- model = cls(
- device="meta",
- dtype=dtype,
- num_layers=num_layers,
- attn_kwargs=attn_kwargs,
- )
+ model = cls(device="meta", dtype=dtype, num_layers=num_layers)
  model = model.requires_grad_(False)
  model.load_state_dict(state_dict, assign=True)
  model.to(device=device, dtype=dtype, non_blocking=True)
@@ -11,12 +11,11 @@ class QwenImageDiTFBCache(QwenImageDiT):
  def __init__(
  self,
  num_layers: int = 60,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  relative_l1_threshold: float = 0.05,
  ):
- super().__init__(num_layers=num_layers, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+ super().__init__(num_layers=num_layers, device=device, dtype=dtype)
  self.relative_l1_threshold = relative_l1_threshold
  self.step_count = 0
  self.num_inference_steps = 0
@@ -43,6 +42,7 @@ class QwenImageDiTFBCache(QwenImageDiT):
  text: torch.Tensor = None,
  timestep: torch.LongTensor = None,
  txt_seq_lens: torch.LongTensor = None,
+ attn_kwargs: Optional[Dict[str, Any]] = None,
  ):
  h, w = image.shape[-2:]
  fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
@@ -72,7 +72,11 @@ class QwenImageDiTFBCache(QwenImageDiT):
  # first block
  original_hidden_states = image
  text, image = self.transformer_blocks[0](
- image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb
+ image=image,
+ text=text,
+ temb=conditioning,
+ image_rotary_emb=image_rotary_emb,
+ attn_kwargs=attn_kwargs,
  )
  first_hidden_states_residual = image - original_hidden_states

@@ -94,7 +98,13 @@ class QwenImageDiTFBCache(QwenImageDiT):
  first_hidden_states = image.clone()

  for block in self.transformer_blocks[1:]:
- text, image = block(image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb)
+ text, image = block(
+ image=image,
+ text=text,
+ temb=conditioning,
+ image_rotary_emb=image_rotary_emb,
+ attn_kwargs=attn_kwargs,
+ )

  previous_residual = image - first_hidden_states
  self.previous_residual = previous_residual
@@ -114,14 +124,12 @@ class QwenImageDiTFBCache(QwenImageDiT):
  device: str,
  dtype: torch.dtype,
  num_layers: int = 60,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  relative_l1_threshold: float = 0.05,
  ):
  model = cls(
  device="meta",
  dtype=dtype,
  num_layers=num_layers,
- attn_kwargs=attn_kwargs,
  relative_l1_threshold=relative_l1_threshold,
  )
  model = model.requires_grad_(False)