diffsynth-engine 0.6.1.dev22__py3-none-any.whl → 0.6.1.dev24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  2. diffsynth_engine/configs/pipeline.py +35 -12
  3. diffsynth_engine/models/basic/attention.py +59 -20
  4. diffsynth_engine/models/basic/transformer_helper.py +36 -2
  5. diffsynth_engine/models/basic/video_sparse_attention.py +235 -0
  6. diffsynth_engine/models/flux/flux_controlnet.py +7 -19
  7. diffsynth_engine/models/flux/flux_dit.py +22 -36
  8. diffsynth_engine/models/flux/flux_dit_fbcache.py +9 -7
  9. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  10. diffsynth_engine/models/qwen_image/qwen_image_dit.py +26 -32
  11. diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  12. diffsynth_engine/models/wan/wan_dit.py +62 -22
  13. diffsynth_engine/pipelines/flux_image.py +11 -10
  14. diffsynth_engine/pipelines/qwen_image.py +16 -15
  15. diffsynth_engine/pipelines/utils.py +52 -0
  16. diffsynth_engine/pipelines/wan_s2v.py +3 -8
  17. diffsynth_engine/pipelines/wan_video.py +11 -13
  18. diffsynth_engine/tokenizers/base.py +6 -0
  19. diffsynth_engine/tokenizers/qwen2.py +12 -4
  20. diffsynth_engine/utils/constants.py +13 -12
  21. diffsynth_engine/utils/flag.py +6 -0
  22. diffsynth_engine/utils/parallel.py +51 -6
  23. {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev24.dist-info}/METADATA +1 -1
  24. {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev24.dist-info}/RECORD +38 -36
  25. /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
  26. /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
  27. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
  28. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
  29. /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
  30. /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
  31. /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
  32. /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
  33. /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
  34. /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
  35. /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
  36. {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev24.dist-info}/WHEEL +0 -0
  37. {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev24.dist-info}/licenses/LICENSE +0 -0
  38. {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev24.dist-info}/top_level.txt +0 -0

diffsynth_engine/models/wan/wan_dit.py

@@ -17,6 +17,7 @@ from diffsynth_engine.utils.constants import (
     WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
     WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
     WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
+    WAN_DIT_KEYMAP_FILE,
 )
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
@@ -30,6 +31,9 @@ from diffsynth_engine.utils.parallel import (
 T5_TOKEN_NUM = 512
 FLF_TOKEN_NUM = 257 * 2

+with open(WAN_DIT_KEYMAP_FILE, "r", encoding="utf-8") as f:
+    config = json.load(f)
+

 def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
     return x * (1 + scale) + shift
@@ -73,7 +77,7 @@ class SelfAttention(nn.Module):
         dim: int,
         num_heads: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -86,19 +90,25 @@ class SelfAttention(nn.Module):
         self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
         self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
         self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        self.gate_compress = nn.Linear(dim, dim, device=device, dtype=dtype) if use_vsa else None

-    def forward(self, x, freqs):
+    def forward(self, x, freqs, attn_kwargs=None):
         q, k, v = self.norm_q(self.q(x)), self.norm_k(self.k(x)), self.v(x)
+        g = self.gate_compress(x) if self.gate_compress is not None else None
+
         num_heads = q.shape[2] // self.head_dim
         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
+        g = rearrange(g, "b s (n d) -> b s n d", n=num_heads) if g is not None else None
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
         x = attention_ops.attention(
             q=rope_apply(q, freqs),
             k=rope_apply(k, freqs),
             v=v,
-            **self.attn_kwargs,
+            g=g,
+            **attn_kwargs,
         )
         x = x.flatten(2)
         return self.o(x)
@@ -111,7 +121,6 @@ class CrossAttention(nn.Module):
         num_heads: int,
         eps: float = 1e-6,
         has_image_input: bool = False,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -130,9 +139,8 @@ class CrossAttention(nn.Module):
             self.k_img = nn.Linear(dim, dim, device=device, dtype=dtype)
             self.v_img = nn.Linear(dim, dim, device=device, dtype=dtype)
             self.norm_k_img = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

-    def forward(self, x: torch.Tensor, y: torch.Tensor):
+    def forward(self, x: torch.Tensor, y: torch.Tensor, attn_kwargs=None):
         if self.has_image_input:
             img = y[:, :-T5_TOKEN_NUM]
             ctx = y[:, -T5_TOKEN_NUM:]
@@ -144,12 +152,16 @@ class CrossAttention(nn.Module):
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)

-        x = attention(q, k, v, **self.attn_kwargs).flatten(2)
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        if attn_kwargs.get("attn_impl", None) == "vsa":
+            attn_kwargs = attn_kwargs.copy()
+            attn_kwargs["attn_impl"] = "sdpa"
+        x = attention(q, k, v, **attn_kwargs).flatten(2)
         if self.has_image_input:
             k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
             k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
             v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
-            y = attention(q, k_img, v_img, **self.attn_kwargs).flatten(2)
+            y = attention(q, k_img, v_img, **attn_kwargs).flatten(2)
             x = x + y
         return self.o(x)

@@ -162,7 +174,7 @@ class DiTBlock(nn.Module):
         num_heads: int,
         ffn_dim: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -170,9 +182,9 @@ class DiTBlock(nn.Module):
         self.dim = dim
         self.num_heads = num_heads
         self.ffn_dim = ffn_dim
-        self.self_attn = SelfAttention(dim, num_heads, eps, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+        self.self_attn = SelfAttention(dim, num_heads, eps, use_vsa=use_vsa, device=device, dtype=dtype)
         self.cross_attn = CrossAttention(
-            dim, num_heads, eps, has_image_input=has_image_input, attn_kwargs=attn_kwargs, device=device, dtype=dtype
+            dim, num_heads, eps, has_image_input=has_image_input, device=device, dtype=dtype
         )
         self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
         self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
@@ -184,14 +196,14 @@ class DiTBlock(nn.Module):
         )
         self.modulation = nn.Parameter(torch.randn(1, 6, dim, device=device, dtype=dtype) / dim**0.5)

-    def forward(self, x, context, t_mod, freqs):
+    def forward(self, x, context, t_mod, freqs, attn_kwargs=None):
         # msa: multi-head self-attention mlp: multi-layer perceptron
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
             t.squeeze(1) for t in (self.modulation + t_mod).chunk(6, dim=1)
         ]
         input_x = modulate(self.norm1(x), shift_msa, scale_msa)
-        x = x + gate_msa * self.self_attn(input_x, freqs)
-        x = x + self.cross_attn(self.norm3(x), context)
+        x = x + gate_msa * self.self_attn(input_x, freqs, attn_kwargs)
+        x = x + self.cross_attn(self.norm3(x), context, attn_kwargs)
         input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
         x = x + gate_mlp * self.ffn(input_x)
         return x
@@ -249,7 +261,26 @@ class Head(nn.Module):


 class WanDiTStateDictConverter(StateDictConverter):
+    def _from_diffusers(self, state_dict):
+        global_rename_dict = config["diffusers"]["global_rename_dict"]
+        rename_dict = config["diffusers"]["rename_dict"]
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            suffix = ""
+            suffix = ".weight" if name.endswith(".weight") else suffix
+            suffix = ".bias" if name.endswith(".bias") else suffix
+            prefix = name[: -len(suffix)] if suffix else name
+            if prefix in global_rename_dict:
+                state_dict_[f"{global_rename_dict[prefix]}{suffix}"] = param
+            if prefix.startswith("blocks."):
+                _, idx, middle = prefix.split(".", 2)
+                if middle in rename_dict:
+                    state_dict_[f"blocks.{idx}.{rename_dict[middle]}{suffix}"] = param
+        return state_dict_
+
     def convert(self, state_dict):
+        if "condition_embedder.time_proj.weight" in state_dict:
+            return self._from_diffusers(state_dict)
         return state_dict


@@ -273,7 +304,7 @@ class WanDiT(PreTrainedModel):
         has_vae_feature: bool = False,
         fuse_image_latents: bool = False,
         flf_pos_emb: bool = False,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -307,7 +338,16 @@ class WanDiT(PreTrainedModel):
         )
         self.blocks = nn.ModuleList(
             [
-                DiTBlock(has_clip_feature, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
+                DiTBlock(
+                    has_clip_feature,
+                    dim,
+                    num_heads,
+                    ffn_dim,
+                    eps,
+                    use_vsa,
+                    device=device,
+                    dtype=dtype,
+                )
                 for _ in range(num_layers)
             ]
         )
@@ -344,6 +384,7 @@ class WanDiT(PreTrainedModel):
         timestep: torch.Tensor,
         clip_feature: Optional[torch.Tensor] = None,  # clip_vision_encoder(img)
         y: Optional[torch.Tensor] = None,  # vae_encoder(img)
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         use_cfg = x.shape[0] > 1
@@ -376,7 +417,7 @@ class WanDiT(PreTrainedModel):

         with sequence_parallel((x, t, t_mod, freqs), seq_dims=(1, 0, 0, 0)):
             for block in self.blocks:
-                x = block(x, context, t_mod, freqs)
+                x = block(x, context, t_mod, freqs, attn_kwargs)
             x = self.head(x, t)
             (x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(f * h * w,))
         x = self.unpatchify(x, (f, h, w))
@@ -409,12 +450,11 @@ class WanDiT(PreTrainedModel):
         config: Dict[str, Any],
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
-        assign: bool = True,
+        use_vsa: bool = False,
     ):
-        model = cls(**config, device="meta", dtype=dtype, attn_kwargs=attn_kwargs)
+        model = cls(**config, device="meta", dtype=dtype, use_vsa=use_vsa)
         model = model.requires_grad_(False)
-        model.load_state_dict(state_dict, assign=assign)
+        model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
         return model

diffsynth_engine/pipelines/flux_image.py

@@ -17,7 +17,12 @@ from diffsynth_engine.models.flux import (
     flux_dit_config,
     flux_text_encoder_config,
 )
-from diffsynth_engine.configs import FluxPipelineConfig, FluxStateDicts, ControlType, ControlNetParams
+from diffsynth_engine.configs import (
+    FluxPipelineConfig,
+    FluxStateDicts,
+    ControlType,
+    ControlNetParams,
+)
 from diffsynth_engine.models.basic.lora import LoRAContext
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
 from diffsynth_engine.pipelines.utils import accumulate, calculate_shift
@@ -507,20 +512,12 @@ class FluxImagePipeline(BasePipeline):
         vae_encoder = FluxVAEEncoder.from_state_dict(state_dicts.vae, device=init_device, dtype=config.vae_dtype)

         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = FluxDiTFBCache.from_state_dict(
                     state_dicts.model,
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:
@@ -529,7 +526,6 @@ class FluxImagePipeline(BasePipeline):
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -755,6 +751,7 @@ class FluxImagePipeline(BasePipeline):
         latents = latents.to(self.dtype)
         self.load_models_to_device(["dit"])

+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
         noise_pred = self.dit(
             hidden_states=latents,
             timestep=timestep,
@@ -766,6 +763,7 @@ class FluxImagePipeline(BasePipeline):
             image_ids=image_ids,
             controlnet_double_block_output=double_block_output,
             controlnet_single_block_output=single_block_output,
+            attn_kwargs=attn_kwargs,
         )
         noise_pred = noise_pred[:, :image_seq_len]
         noise_pred = self.dit.unpatchify(noise_pred, height, width)
@@ -887,6 +885,8 @@ class FluxImagePipeline(BasePipeline):
             if self.offload_mode is not None:
                 empty_cache()
                 param.model.to(self.device)
+
+            attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
             double_block_output, single_block_output = param.model(
                 hidden_states=latents,
                 control_condition=control_condition,
@@ -897,6 +897,7 @@ class FluxImagePipeline(BasePipeline):
                 image_ids=image_ids,
                 text_ids=text_ids,
                 guidance=guidance,
+                attn_kwargs=attn_kwargs,
             )
             if self.offload_mode is not None:
                 param.model.to("cpu")

diffsynth_engine/pipelines/qwen_image.py

@@ -24,7 +24,7 @@ from diffsynth_engine.models.qwen_image import (
 from diffsynth_engine.models.qwen_image import QwenImageVAE
 from diffsynth_engine.tokenizers import Qwen2TokenizerFast, Qwen2VLProcessor
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
-from diffsynth_engine.pipelines.utils import calculate_shift
+from diffsynth_engine.pipelines.utils import calculate_shift, pad_and_concat
 from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
 from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
 from diffsynth_engine.utils.constants import (
@@ -91,7 +91,7 @@ class QwenImageLoRAConverter(LoRAStateDictConverter):
             if "lora_A.weight" in key:
                 lora_a_suffix = "lora_A.weight"
                 lora_b_suffix = "lora_B.weight"
-
+
             if lora_a_suffix is None:
                 continue

@@ -148,9 +148,17 @@ class QwenImagePipeline(BasePipeline):
         self.prompt_template_encode_start_idx = 34
         # qwen image edit
         self.edit_system_prompt = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
-        self.edit_prompt_template_encode = "<|im_start|>system\n" + self.edit_system_prompt + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+        self.edit_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+        )
         # qwen image edit plus
-        self.edit_plus_prompt_template_encode = "<|im_start|>system\n" + self.edit_system_prompt + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        self.edit_plus_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        )

         self.edit_prompt_template_encode_start_idx = 64

@@ -253,19 +261,11 @@ class QwenImagePipeline(BasePipeline):
         )

         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = QwenImageDiTFBCache.from_state_dict(
                     state_dicts.model,
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:
@@ -273,7 +273,6 @@ class QwenImagePipeline(BasePipeline):
                     state_dicts.model,
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -499,8 +498,8 @@ class QwenImagePipeline(BasePipeline):
         else:
             # cfg by predict noise in one batch
             bs, _, h, w = latents.shape
-            prompt_emb = torch.cat([prompt_emb, negative_prompt_emb], dim=0)
-            prompt_emb_mask = torch.cat([prompt_emb_mask, negative_prompt_emb_mask], dim=0)
+            prompt_emb = pad_and_concat(prompt_emb, negative_prompt_emb)
+            prompt_emb_mask = pad_and_concat(prompt_emb_mask, negative_prompt_emb_mask)
             if entity_prompt_embs is not None:
                 entity_prompt_embs = [
                     torch.cat([x, y], dim=0) for x, y in zip(entity_prompt_embs, negative_entity_prompt_embs)
@@ -548,6 +547,7 @@ class QwenImagePipeline(BasePipeline):
         entity_masks: Optional[List[torch.Tensor]] = None,
     ):
         self.load_models_to_device(["dit"])
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
         noise_pred = self.dit(
             image=latents,
             edit=image_latents,
@@ -558,6 +558,7 @@ class QwenImagePipeline(BasePipeline):
             entity_text=entity_prompt_embs,
             entity_seq_lens=[mask.sum(dim=1) for mask in entity_prompt_emb_masks] if entity_prompt_emb_masks else None,
             entity_masks=entity_masks,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred

diffsynth_engine/pipelines/utils.py

@@ -1,3 +1,7 @@
+import torch
+import torch.nn.functional as F
+
+
 def accumulate(result, new_item):
     if result is None:
         return new_item
@@ -17,3 +21,51 @@ def calculate_shift(
     b = base_shift - m * base_seq_len
     mu = image_seq_len * m + b
     return mu
+
+
+def pad_and_concat(
+    tensor1: torch.Tensor,
+    tensor2: torch.Tensor,
+    concat_dim: int = 0,
+    pad_dim: int = 1,
+) -> torch.Tensor:
+    """
+    Concatenate two tensors along a specified dimension after padding along another dimension.
+
+    Assumes input tensors have shape (b, s, d), where:
+    - b: batch dimension
+    - s: sequence dimension (may differ)
+    - d: feature dimension
+
+    Args:
+        tensor1: First tensor with shape (b1, s1, d)
+        tensor2: Second tensor with shape (b2, s2, d)
+        concat_dim: Dimension to concatenate along, default is 0 (batch dimension)
+        pad_dim: Dimension to pad along, default is 1 (sequence dimension)
+
+    Returns:
+        Concatenated tensor, shape depends on concat_dim and pad_dim choices
+    """
+    assert tensor1.dim() == tensor2.dim(), "Both tensors must have the same number of dimensions"
+    assert concat_dim != pad_dim, "concat_dim and pad_dim cannot be the same"
+
+    len1, len2 = tensor1.shape[pad_dim], tensor2.shape[pad_dim]
+    max_len = max(len1, len2)
+
+    # Calculate the position of pad_dim in the padding list
+    # Padding format: from the last dimension, each pair represents (dim_n_left, dim_n_right, ..., dim_0_left, dim_0_right)
+    ndim = tensor1.dim()
+    padding = [0] * (2 * ndim)
+    pad_right_idx = -2 * pad_dim - 1
+
+    if len1 < max_len:
+        pad_len = max_len - len1
+        padding[pad_right_idx] = pad_len
+        tensor1 = F.pad(tensor1, padding, mode="constant", value=0)
+    elif len2 < max_len:
+        pad_len = max_len - len2
+        padding[pad_right_idx] = pad_len
+        tensor2 = F.pad(tensor2, padding, mode="constant", value=0)
+
+    # Concatenate along the specified dimension
+    return torch.cat([tensor1, tensor2], dim=concat_dim)
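
As a quick illustration of the helper added above (not part of the package; shapes picked arbitrarily), pad_and_concat right-pads the shorter tensor with zeros along the sequence dimension before stacking along the batch dimension, which is what lets prompt and negative-prompt embeddings of different token lengths be batched for CFG in qwen_image.py:

    import torch
    from diffsynth_engine.pipelines.utils import pad_and_concat

    prompt_emb = torch.randn(1, 20, 3584)           # 20 prompt tokens (feature dim chosen arbitrarily)
    negative_prompt_emb = torch.randn(1, 13, 3584)  # 13 negative-prompt tokens

    batched = pad_and_concat(prompt_emb, negative_prompt_emb)
    print(batched.shape)  # torch.Size([2, 20, 3584]); positions 13..19 of the second row are zero padding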

diffsynth_engine/pipelines/wan_s2v.py

@@ -394,6 +394,7 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         void_audio_input: torch.Tensor | None = None,
     ):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)

         noise_pred = model(
             x=latents,
@@ -408,6 +409,7 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
             drop_motion_frames=drop_motion_frames,
             audio_mask=audio_mask,
             void_audio_input=void_audio_input,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred

@@ -654,19 +656,12 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         )

         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             dit = WanS2VDiT.from_state_dict(
                 state_dicts.model,
                 config=model_config,
                 device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-                attn_kwargs=attn_kwargs,
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)

diffsynth_engine/pipelines/wan_video.py

@@ -323,6 +323,7 @@ class WanVideoPipeline(BasePipeline):

     def predict_noise(self, model, latents, image_clip_feature, image_y, timestep, context):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)

         noise_pred = model(
             x=latents,
@@ -330,6 +331,7 @@ class WanVideoPipeline(BasePipeline):
             context=context,
             clip_feature=image_clip_feature,
             y=image_y,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred

@@ -578,19 +580,12 @@ class WanVideoPipeline(BasePipeline):
         dit_state_dict = state_dicts.model

         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             dit = WanDiT.from_state_dict(
                 dit_state_dict,
                 config=dit_config,
                 device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-                attn_kwargs=attn_kwargs,
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -602,7 +597,7 @@ class WanVideoPipeline(BasePipeline):
                 config=dit_config,
                 device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-                attn_kwargs=attn_kwargs,
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit2)
@@ -640,19 +635,22 @@ class WanVideoPipeline(BasePipeline):
     @staticmethod
     def _get_dit_type(model_state_dict: Dict[str, torch.Tensor] | Dict[str, Dict[str, torch.Tensor]]) -> str:
         # determine wan dit type by model params
+        def has_any_key(*xs):
+            return any(x in model_state_dict for x in xs)
+
         dit_type = None
-        if "high_noise_model" in model_state_dict and "low_noise_model" in model_state_dict:
+        if has_any_key("high_noise_model"):
             if model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 36:
                 dit_type = "wan2.2-i2v-a14b"
             elif model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 16:
                 dit_type = "wan2.2-t2v-a14b"
         elif model_state_dict["patch_embedding.weight"].shape[1] == 48:
             dit_type = "wan2.2-ti2v-5b"
-        elif "img_emb.emb_pos" in model_state_dict:
+        elif has_any_key("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed"):
             dit_type = "wan2.1-flf2v-14b"
-        elif "img_emb.proj.0.weight" in model_state_dict:
+        elif has_any_key("img_emb.proj.0.weight", "condition_embedder.image_embedder.norm1"):
             dit_type = "wan2.1-i2v-14b"
-        elif "blocks.39.self_attn.norm_q.weight" in model_state_dict:
+        elif has_any_key("blocks.39.self_attn.norm_q.weight", "blocks.39.attn1.norm_q.weight"):
             dit_type = "wan2.1-t2v-14b"
         else:
             dit_type = "wan2.1-t2v-1.3b"

diffsynth_engine/tokenizers/base.py

@@ -1,10 +1,16 @@
 # Modified from transformers.tokenization_utils_base
 from typing import Dict, List, Union, overload
+from enum import Enum


 TOKENIZER_CONFIG_FILE = "tokenizer_config.json"


+class PaddingStrategy(str, Enum):
+    LONGEST = "longest"
+    MAX_LENGTH = "max_length"
+
+
 class BaseTokenizer:
     SPECIAL_TOKENS_ATTRIBUTES = [
         "bos_token",

diffsynth_engine/tokenizers/qwen2.py

@@ -4,7 +4,7 @@ import torch
 from typing import Dict, List, Union, Optional
 from tokenizers import Tokenizer as TokenizerFast, AddedToken

-from diffsynth_engine.tokenizers.base import BaseTokenizer, TOKENIZER_CONFIG_FILE
+from diffsynth_engine.tokenizers.base import BaseTokenizer, PaddingStrategy, TOKENIZER_CONFIG_FILE


 VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
@@ -165,22 +165,28 @@ class Qwen2TokenizerFast(BaseTokenizer):
         texts: Union[str, List[str]],
         max_length: Optional[int] = None,
         padding_side: Optional[str] = None,
+        padding_strategy: Union[PaddingStrategy, str] = "longest",
         **kwargs,
     ) -> Dict[str, "torch.Tensor"]:
         """
         Tokenize text and prepare for model inputs.

         Args:
-            text (`str`, `List[str]`, *optional*):
+            texts (`str`, `List[str]`):
                 The sequence or batch of sequences to be encoded.

             max_length (`int`, *optional*):
-                Each encoded sequence will be truncated or padded to max_length.
+                Maximum length of the encoded sequences.

             padding_side (`str`, *optional*):
                 The side on which the padding should be applied. Should be selected between `"right"` and `"left"`.
                 Defaults to `"right"`.

+            padding_strategy (`PaddingStrategy`, `str`, *optional*):
+                If `"longest"`, will pad the sequences to the longest sequence in the batch.
+                If `"max_length"`, will pad the sequences to the `max_length` argument.
+                Defaults to `"longest"`.
+
         Returns:
             `Dict[str, "torch.Tensor"]`: tensor dict compatible with model_input_names.
         """
@@ -190,7 +196,9 @@ class Qwen2TokenizerFast(BaseTokenizer):

         batch_ids = self.batch_encode(texts)
         ids_lens = [len(ids_) for ids_ in batch_ids]
-        max_length = max_length if max_length is not None else min(max(ids_lens), self.model_max_length)
+        max_length = max_length if max_length is not None else self.model_max_length
+        if padding_strategy == PaddingStrategy.LONGEST:
+            max_length = min(max(ids_lens), max_length)
         padding_side = padding_side if padding_side is not None else self.padding_side

         encoded = torch.zeros(len(texts), max_length, dtype=torch.long)
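
For reference, the padding_strategy argument above only changes how the target length is resolved before the zero-filled tensors are built; a standalone sketch of that resolution logic with assumed values (no tokenizer instance involved):

    ids_lens = [9, 17, 12]    # token counts for a batch of three texts
    model_max_length = 512

    # padding_strategy="longest" (default): pad to the longest sequence in the batch, capped by model_max_length
    max_length_longest = min(max(ids_lens), model_max_length)  # -> 17

    # padding_strategy="max_length": pad every sequence to model_max_length (or an explicit max_length argument)
    max_length_fixed = model_max_length  # -> 512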

diffsynth_engine/utils/constants.py

@@ -27,18 +27,19 @@ SD3_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd3", "sd3_tex
 SDXL_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_text_encoder.json")
 SDXL_UNET_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_unet.json")

-WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-1.3b.json")
-WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-14b.json")
-WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-i2v-14b.json")
-WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-flf2v-14b.json")
-WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-ti2v-5b.json")
-WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-t2v-a14b.json")
-WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-i2v-a14b.json")
-WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-s2v-14b.json")
-
-WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1-vae.json")
-WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2-vae.json")
-WAN_VAE_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan-vae-keymap.json")
+WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_1.3b.json")
+WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_14b.json")
+WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_i2v_14b.json")
+WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_flf2v_14b.json")
+WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_ti2v_5b.json")
+WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_t2v_a14b.json")
+WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_i2v_a14b.json")
+WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_s2v_14b.json")
+WAN_DIT_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan_dit_keymap.json")
+
+WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1_vae.json")
+WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2_vae.json")
+WAN_VAE_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan_vae_keymap.json")

 QWEN_IMAGE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen2_5_vl_config.json")
 QWEN_IMAGE_VISION_CONFIG_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen2_5_vl_vision_config.json")

diffsynth_engine/utils/flag.py

@@ -44,3 +44,9 @@ if SPARGE_ATTN_AVAILABLE:
     logger.info("Sparge attention is available")
 else:
     logger.info("Sparge attention is not available")
+
+VIDEO_SPARSE_ATTN_AVAILABLE = importlib.util.find_spec("vsa") is not None
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    logger.info("Video sparse attention is available")
+else:
+    logger.info("Video sparse attention is not available")
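
A minimal sketch of how a caller might consume the new availability flag (hypothetical usage; the actual gating inside the package is not shown in this diff):

    from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE

    # Fall back to "sdpa" when the optional vsa package is not installed
    attn_impl = "vsa" if VIDEO_SPARSE_ATTN_AVAILABLE else "sdpa"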