diffsynth-engine 0.6.1.dev22__py3-none-any.whl → 0.6.1.dev23__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (34)
  1. diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  2. diffsynth_engine/configs/pipeline.py +33 -5
  3. diffsynth_engine/models/basic/attention.py +59 -20
  4. diffsynth_engine/models/basic/video_sparse_attention.py +235 -0
  5. diffsynth_engine/models/flux/flux_controlnet.py +7 -19
  6. diffsynth_engine/models/flux/flux_dit.py +22 -36
  7. diffsynth_engine/models/flux/flux_dit_fbcache.py +9 -7
  8. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  9. diffsynth_engine/models/qwen_image/qwen_image_dit.py +13 -15
  10. diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  11. diffsynth_engine/models/wan/wan_dit.py +62 -22
  12. diffsynth_engine/pipelines/flux_image.py +11 -10
  13. diffsynth_engine/pipelines/qwen_image.py +3 -10
  14. diffsynth_engine/pipelines/wan_s2v.py +3 -8
  15. diffsynth_engine/pipelines/wan_video.py +11 -13
  16. diffsynth_engine/utils/constants.py +13 -12
  17. diffsynth_engine/utils/flag.py +6 -0
  18. diffsynth_engine/utils/parallel.py +51 -6
  19. {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/METADATA +1 -1
  20. {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/RECORD +34 -32
  21. /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
  22. /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
  23. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
  24. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
  25. /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
  26. /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
  27. /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
  28. /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
  29. /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
  30. /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
  31. /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
  32. {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/WHEEL +0 -0
  33. {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/licenses/LICENSE +0 -0
  34. {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/top_level.txt +0 -0
diffsynth_engine/models/wan/wan_dit.py

@@ -17,6 +17,7 @@ from diffsynth_engine.utils.constants import (
     WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
     WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
     WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
+    WAN_DIT_KEYMAP_FILE,
 )
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
@@ -30,6 +31,9 @@ from diffsynth_engine.utils.parallel import (
 T5_TOKEN_NUM = 512
 FLF_TOKEN_NUM = 257 * 2
 
+with open(WAN_DIT_KEYMAP_FILE, "r", encoding="utf-8") as f:
+    config = json.load(f)
+
 
 def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
     return x * (1 + scale) + shift
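
Note: WanDiTStateDictConverter._from_diffusers (further down in this file) indexes this module-level JSON as config["diffusers"]["global_rename_dict"] and config["diffusers"]["rename_dict"], so the new wan_dit_keymap.json (+41 lines, listed above but not shown in this section) presumably has roughly the following shape, written here as a Python literal. The concrete key pairs are illustrative guesses, not the shipped mapping.

    # Hypothetical shape of wan_dit_keymap.json, inferred from how
    # _from_diffusers indexes it; both entries below are placeholders.
    {
        "diffusers": {
            # full parameter prefixes outside the transformer blocks
            "global_rename_dict": {
                "condition_embedder.time_proj": "time_embedding.0",  # hypothetical
            },
            # per-block remainders, applied after "blocks.<idx>."
            "rename_dict": {
                "attn1.to_q": "self_attn.q",  # hypothetical
            },
        }
    }
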
@@ -73,7 +77,7 @@ class SelfAttention(nn.Module):
         dim: int,
         num_heads: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -86,19 +90,25 @@ class SelfAttention(nn.Module):
         self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
         self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
         self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        self.gate_compress = nn.Linear(dim, dim, device=device, dtype=dtype) if use_vsa else None
 
-    def forward(self, x, freqs):
+    def forward(self, x, freqs, attn_kwargs=None):
         q, k, v = self.norm_q(self.q(x)), self.norm_k(self.k(x)), self.v(x)
+        g = self.gate_compress(x) if self.gate_compress is not None else None
+
         num_heads = q.shape[2] // self.head_dim
         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
+        g = rearrange(g, "b s (n d) -> b s n d", n=num_heads) if g is not None else None
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
         x = attention_ops.attention(
             q=rope_apply(q, freqs),
             k=rope_apply(k, freqs),
             v=v,
-            **self.attn_kwargs,
+            g=g,
+            **attn_kwargs,
         )
         x = x.flatten(2)
         return self.o(x)
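
Note: two things change at once in SelfAttention. Attention options (attn_kwargs) move from construction time to call time, and use_vsa adds a fourth projection whose output g is reshaped like q/k/v and passed to the kernel as a gate. How the real vsa kernel consumes g is defined in the new video_sparse_attention.py, which is not shown in this section; the standalone toy below demonstrates only the gated-attention pattern, and the sigmoid combination is an assumption made for illustration.

    # Toy of the gated-attention pattern: a fourth linear projection
    # modulates the attention output. Sketch only; not the vsa kernel.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GatedToyAttention(nn.Module):
        def __init__(self, dim: int, num_heads: int):
            super().__init__()
            self.num_heads = num_heads
            self.qkv = nn.Linear(dim, 3 * dim)
            self.gate = nn.Linear(dim, dim)  # stand-in for gate_compress
            self.o = nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            b, s, d = x.shape
            q, k, v = (
                t.view(b, s, self.num_heads, -1).transpose(1, 2)
                for t in self.qkv(x).chunk(3, dim=-1)
            )
            out = F.scaled_dot_product_attention(q, k, v)
            out = out.transpose(1, 2).reshape(b, s, d)
            return self.o(out * torch.sigmoid(self.gate(x)))  # gated output

    _ = GatedToyAttention(64, 4)(torch.randn(2, 16, 64))  # smoke test
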
@@ -111,7 +121,6 @@ class CrossAttention(nn.Module):
         num_heads: int,
         eps: float = 1e-6,
         has_image_input: bool = False,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -130,9 +139,8 @@ class CrossAttention(nn.Module):
             self.k_img = nn.Linear(dim, dim, device=device, dtype=dtype)
             self.v_img = nn.Linear(dim, dim, device=device, dtype=dtype)
             self.norm_k_img = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
 
-    def forward(self, x: torch.Tensor, y: torch.Tensor):
+    def forward(self, x: torch.Tensor, y: torch.Tensor, attn_kwargs=None):
         if self.has_image_input:
             img = y[:, :-T5_TOKEN_NUM]
             ctx = y[:, -T5_TOKEN_NUM:]
@@ -144,12 +152,16 @@ class CrossAttention(nn.Module):
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
 
-        x = attention(q, k, v, **self.attn_kwargs).flatten(2)
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        if attn_kwargs.get("attn_impl", None) == "vsa":
+            attn_kwargs = attn_kwargs.copy()
+            attn_kwargs["attn_impl"] = "sdpa"
+        x = attention(q, k, v, **attn_kwargs).flatten(2)
         if self.has_image_input:
             k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
             k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
             v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
-            y = attention(q, k_img, v_img, **self.attn_kwargs).flatten(2)
+            y = attention(q, k_img, v_img, **attn_kwargs).flatten(2)
             x = x + y
         return self.o(x)
 
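
Note: the vsa-to-sdpa downgrade above is deliberate, presumably because video sparse attention requires queries and keys to share a 3-D video token layout, which holds for self-attention but not for cross-attention against text or CLIP image tokens. The kwargs dict is copied before the override so the caller's dict, which the self-attention path still needs, is never mutated. The same rule as a standalone helper:

    # Standalone restatement of the fallback rule used above; "sdpa" is
    # the dense PyTorch scaled-dot-product fallback.
    def cross_attn_kwargs(attn_kwargs: dict | None) -> dict:
        kwargs = dict(attn_kwargs or {})
        if kwargs.get("attn_impl") == "vsa":
            kwargs["attn_impl"] = "sdpa"  # copy first, never mutate the caller's dict
        return kwargs

    assert cross_attn_kwargs({"attn_impl": "vsa"}) == {"attn_impl": "sdpa"}
    assert cross_attn_kwargs(None) == {}
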
@@ -162,7 +174,7 @@ class DiTBlock(nn.Module):
         num_heads: int,
         ffn_dim: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -170,9 +182,9 @@ class DiTBlock(nn.Module):
         self.dim = dim
         self.num_heads = num_heads
         self.ffn_dim = ffn_dim
-        self.self_attn = SelfAttention(dim, num_heads, eps, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+        self.self_attn = SelfAttention(dim, num_heads, eps, use_vsa=use_vsa, device=device, dtype=dtype)
         self.cross_attn = CrossAttention(
-            dim, num_heads, eps, has_image_input=has_image_input, attn_kwargs=attn_kwargs, device=device, dtype=dtype
+            dim, num_heads, eps, has_image_input=has_image_input, device=device, dtype=dtype
         )
         self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
         self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
@@ -184,14 +196,14 @@ class DiTBlock(nn.Module):
         )
         self.modulation = nn.Parameter(torch.randn(1, 6, dim, device=device, dtype=dtype) / dim**0.5)
 
-    def forward(self, x, context, t_mod, freqs):
+    def forward(self, x, context, t_mod, freqs, attn_kwargs=None):
         # msa: multi-head self-attention mlp: multi-layer perceptron
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
             t.squeeze(1) for t in (self.modulation + t_mod).chunk(6, dim=1)
         ]
         input_x = modulate(self.norm1(x), shift_msa, scale_msa)
-        x = x + gate_msa * self.self_attn(input_x, freqs)
-        x = x + self.cross_attn(self.norm3(x), context)
+        x = x + gate_msa * self.self_attn(input_x, freqs, attn_kwargs)
+        x = x + self.cross_attn(self.norm3(x), context, attn_kwargs)
         input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
         x = x + gate_mlp * self.ffn(input_x)
         return x
@@ -249,7 +261,26 @@ class Head(nn.Module):
 
 
 class WanDiTStateDictConverter(StateDictConverter):
+    def _from_diffusers(self, state_dict):
+        global_rename_dict = config["diffusers"]["global_rename_dict"]
+        rename_dict = config["diffusers"]["rename_dict"]
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            suffix = ""
+            suffix = ".weight" if name.endswith(".weight") else suffix
+            suffix = ".bias" if name.endswith(".bias") else suffix
+            prefix = name[: -len(suffix)] if suffix else name
+            if prefix in global_rename_dict:
+                state_dict_[f"{global_rename_dict[prefix]}{suffix}"] = param
+            if prefix.startswith("blocks."):
+                _, idx, middle = prefix.split(".", 2)
+                if middle in rename_dict:
+                    state_dict_[f"blocks.{idx}.{rename_dict[middle]}{suffix}"] = param
+        return state_dict_
+
     def convert(self, state_dict):
+        if "condition_embedder.time_proj.weight" in state_dict:
+            return self._from_diffusers(state_dict)
         return state_dict
 
 
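
Note: the converter routes each Diffusers-style key through at most two tables: whole-prefix renames via global_rename_dict, and per-block renames via rename_dict after splitting off "blocks.<idx>.". A worked pass over a single block parameter, using a made-up rename_dict entry (the shipped table lives in wan_dit_keymap.json):

    # One key through _from_diffusers, step by step; the rename_dict
    # entry is hypothetical, the mechanics match the code above.
    rename_dict = {"attn1.to_q": "self_attn.q"}  # hypothetical entry
    name = "blocks.7.attn1.to_q.weight"          # Diffusers-style key
    suffix = ".weight"                           # peeled off first
    prefix = name[: -len(suffix)]                # "blocks.7.attn1.to_q"
    _, idx, middle = prefix.split(".", 2)        # idx="7", middle="attn1.to_q"
    assert f"blocks.{idx}.{rename_dict[middle]}{suffix}" == "blocks.7.self_attn.q.weight"
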
@@ -273,7 +304,7 @@ class WanDiT(PreTrainedModel):
         has_vae_feature: bool = False,
         fuse_image_latents: bool = False,
         flf_pos_emb: bool = False,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -307,7 +338,16 @@ class WanDiT(PreTrainedModel):
         )
         self.blocks = nn.ModuleList(
             [
-                DiTBlock(has_clip_feature, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
+                DiTBlock(
+                    has_clip_feature,
+                    dim,
+                    num_heads,
+                    ffn_dim,
+                    eps,
+                    use_vsa,
+                    device=device,
+                    dtype=dtype,
+                )
                 for _ in range(num_layers)
             ]
         )
@@ -344,6 +384,7 @@ class WanDiT(PreTrainedModel):
         timestep: torch.Tensor,
         clip_feature: Optional[torch.Tensor] = None,  # clip_vision_encoder(img)
         y: Optional[torch.Tensor] = None,  # vae_encoder(img)
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         use_cfg = x.shape[0] > 1
@@ -376,7 +417,7 @@ class WanDiT(PreTrainedModel):
 
         with sequence_parallel((x, t, t_mod, freqs), seq_dims=(1, 0, 0, 0)):
             for block in self.blocks:
-                x = block(x, context, t_mod, freqs)
+                x = block(x, context, t_mod, freqs, attn_kwargs)
             x = self.head(x, t)
         (x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(f * h * w,))
         x = self.unpatchify(x, (f, h, w))
@@ -409,12 +450,11 @@ class WanDiT(PreTrainedModel):
         config: Dict[str, Any],
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
-        assign: bool = True,
+        use_vsa: bool = False,
     ):
-        model = cls(**config, device="meta", dtype=dtype, attn_kwargs=attn_kwargs)
+        model = cls(**config, device="meta", dtype=dtype, use_vsa=use_vsa)
         model = model.requires_grad_(False)
-        model.load_state_dict(state_dict, assign=assign)
+        model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
         return model
 
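
Note: the loader surface narrows here: attn_kwargs and assign are gone from from_state_dict, sparsity becomes a boolean fixed at build time (it has to be, since gate_compress adds weights to the module tree), and per-step options now arrive through forward. A usage sketch; all inputs are placeholders:

    # Sketch of the new load/call pattern; arguments are placeholders.
    import torch
    from diffsynth_engine.models.wan.wan_dit import WanDiT

    def load_and_denoise(state_dict, dit_config, latents, ctx, t):
        dit = WanDiT.from_state_dict(
            state_dict,
            config=dit_config,
            device="cuda:0",
            dtype=torch.bfloat16,
            use_vsa=True,  # fixed at build time: creates gate_compress weights
        )
        # per-call options replace the old per-module attn_kwargs
        return dit(x=latents, context=ctx, timestep=t, attn_kwargs={"attn_impl": "vsa"})
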
diffsynth_engine/pipelines/flux_image.py

@@ -17,7 +17,12 @@ from diffsynth_engine.models.flux import (
     flux_dit_config,
     flux_text_encoder_config,
 )
-from diffsynth_engine.configs import FluxPipelineConfig, FluxStateDicts, ControlType, ControlNetParams
+from diffsynth_engine.configs import (
+    FluxPipelineConfig,
+    FluxStateDicts,
+    ControlType,
+    ControlNetParams,
+)
 from diffsynth_engine.models.basic.lora import LoRAContext
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
 from diffsynth_engine.pipelines.utils import accumulate, calculate_shift
@@ -507,20 +512,12 @@ class FluxImagePipeline(BasePipeline):
         vae_encoder = FluxVAEEncoder.from_state_dict(state_dicts.vae, device=init_device, dtype=config.vae_dtype)
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = FluxDiTFBCache.from_state_dict(
                     state_dicts.model,
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:
@@ -529,7 +526,6 @@ class FluxImagePipeline(BasePipeline):
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -755,6 +751,7 @@ class FluxImagePipeline(BasePipeline):
         latents = latents.to(self.dtype)
         self.load_models_to_device(["dit"])
 
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
         noise_pred = self.dit(
             hidden_states=latents,
             timestep=timestep,
@@ -766,6 +763,7 @@ class FluxImagePipeline(BasePipeline):
             image_ids=image_ids,
             controlnet_double_block_output=double_block_output,
             controlnet_single_block_output=single_block_output,
+            attn_kwargs=attn_kwargs,
         )
         noise_pred = noise_pred[:, :image_seq_len]
         noise_pred = self.dit.unpatchify(noise_pred, height, width)
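
Note: get_attn_kwargs is a new method on the pipeline config (diffsynth_engine/configs/pipeline.py, +33 -5 in the file list; its hunk is not part of this section). Judging only from the call sites in this diff, it rebuilds, once per denoising step, the static dict that was deleted from from_pretrained above. The version below is a speculative reconstruction, not the shipped code:

    # Speculative sketch of PipelineConfig.get_attn_kwargs based only on
    # its call sites in this diff; the real body is in configs/pipeline.py.
    def get_attn_kwargs(self, latents, device):
        kwargs = {
            "attn_impl": self.dit_attn_impl.value,
            "sparge_smooth_k": self.sparge_smooth_k,
            "sparge_cdfthreshd": self.sparge_cdfthreshd,
            "sparge_simthreshd1": self.sparge_simthreshd1,
            "sparge_pvthreshd": self.sparge_pvthreshd,
        }
        # latents/device are presumably passed because "vsa" needs shape-
        # and device-dependent extras (block sparsity over the latent
        # video grid); those keys are unknown from this diff.
        return kwargs
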
@@ -887,6 +885,8 @@ class FluxImagePipeline(BasePipeline):
             if self.offload_mode is not None:
                 empty_cache()
                 param.model.to(self.device)
+
+            attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
             double_block_output, single_block_output = param.model(
                 hidden_states=latents,
                 control_condition=control_condition,
@@ -897,6 +897,7 @@ class FluxImagePipeline(BasePipeline):
                 image_ids=image_ids,
                 text_ids=text_ids,
                 guidance=guidance,
+                attn_kwargs=attn_kwargs,
             )
             if self.offload_mode is not None:
                 param.model.to("cpu")

diffsynth_engine/pipelines/qwen_image.py

@@ -91,7 +91,7 @@ class QwenImageLoRAConverter(LoRAStateDictConverter):
             if "lora_A.weight" in key:
                 lora_a_suffix = "lora_A.weight"
                 lora_b_suffix = "lora_B.weight"
-
+
             if lora_a_suffix is None:
                 continue
 
@@ -253,19 +253,11 @@ class QwenImagePipeline(BasePipeline):
         )
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = QwenImageDiTFBCache.from_state_dict(
                     state_dicts.model,
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:
@@ -273,7 +265,6 @@ class QwenImagePipeline(BasePipeline):
                     state_dicts.model,
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -548,6 +539,7 @@ class QwenImagePipeline(BasePipeline):
         entity_masks: Optional[List[torch.Tensor]] = None,
     ):
         self.load_models_to_device(["dit"])
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
         noise_pred = self.dit(
             image=latents,
             edit=image_latents,
@@ -558,6 +550,7 @@ class QwenImagePipeline(BasePipeline):
             entity_text=entity_prompt_embs,
             entity_seq_lens=[mask.sum(dim=1) for mask in entity_prompt_emb_masks] if entity_prompt_emb_masks else None,
             entity_masks=entity_masks,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred
 

diffsynth_engine/pipelines/wan_s2v.py

@@ -394,6 +394,7 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         void_audio_input: torch.Tensor | None = None,
     ):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
 
         noise_pred = model(
             x=latents,
@@ -408,6 +409,7 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
             drop_motion_frames=drop_motion_frames,
             audio_mask=audio_mask,
             void_audio_input=void_audio_input,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred
 
@@ -654,19 +656,12 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         )
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             dit = WanS2VDiT.from_state_dict(
                 state_dicts.model,
                 config=model_config,
                 device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-                attn_kwargs=attn_kwargs,
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)

diffsynth_engine/pipelines/wan_video.py

@@ -323,6 +323,7 @@ class WanVideoPipeline(BasePipeline):
 
     def predict_noise(self, model, latents, image_clip_feature, image_y, timestep, context):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
 
         noise_pred = model(
             x=latents,
@@ -330,6 +331,7 @@ class WanVideoPipeline(BasePipeline):
             context=context,
             clip_feature=image_clip_feature,
             y=image_y,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred
 
@@ -578,19 +580,12 @@ class WanVideoPipeline(BasePipeline):
         dit_state_dict = state_dicts.model
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             dit = WanDiT.from_state_dict(
                 dit_state_dict,
                 config=dit_config,
                 device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-                attn_kwargs=attn_kwargs,
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -602,7 +597,7 @@ class WanVideoPipeline(BasePipeline):
                 config=dit_config,
                 device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-                attn_kwargs=attn_kwargs,
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit2)
@@ -640,19 +635,22 @@ class WanVideoPipeline(BasePipeline):
     @staticmethod
     def _get_dit_type(model_state_dict: Dict[str, torch.Tensor] | Dict[str, Dict[str, torch.Tensor]]) -> str:
         # determine wan dit type by model params
+        def has_any_key(*xs):
+            return any(x in model_state_dict for x in xs)
+
         dit_type = None
-        if "high_noise_model" in model_state_dict and "low_noise_model" in model_state_dict:
+        if has_any_key("high_noise_model"):
             if model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 36:
                 dit_type = "wan2.2-i2v-a14b"
             elif model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 16:
                 dit_type = "wan2.2-t2v-a14b"
         elif model_state_dict["patch_embedding.weight"].shape[1] == 48:
             dit_type = "wan2.2-ti2v-5b"
-        elif "img_emb.emb_pos" in model_state_dict:
+        elif has_any_key("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed"):
             dit_type = "wan2.1-flf2v-14b"
-        elif "img_emb.proj.0.weight" in model_state_dict:
+        elif has_any_key("img_emb.proj.0.weight", "condition_embedder.image_embedder.norm1"):
             dit_type = "wan2.1-i2v-14b"
-        elif "blocks.39.self_attn.norm_q.weight" in model_state_dict:
+        elif has_any_key("blocks.39.self_attn.norm_q.weight", "blocks.39.attn1.norm_q.weight"):
             dit_type = "wan2.1-t2v-14b"
         else:
             dit_type = "wan2.1-t2v-1.3b"
diffsynth_engine/utils/constants.py

@@ -27,18 +27,19 @@ SD3_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd3", "sd3_tex
 SDXL_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_text_encoder.json")
 SDXL_UNET_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_unet.json")
 
-WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-1.3b.json")
-WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-14b.json")
-WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-i2v-14b.json")
-WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-flf2v-14b.json")
-WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-ti2v-5b.json")
-WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-t2v-a14b.json")
-WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-i2v-a14b.json")
-WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-s2v-14b.json")
-
-WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1-vae.json")
-WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2-vae.json")
-WAN_VAE_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan-vae-keymap.json")
+WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_1.3b.json")
+WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_14b.json")
+WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_i2v_14b.json")
+WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_flf2v_14b.json")
+WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_ti2v_5b.json")
+WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_t2v_a14b.json")
+WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_i2v_a14b.json")
+WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_s2v_14b.json")
+WAN_DIT_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan_dit_keymap.json")
+
+WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1_vae.json")
+WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2_vae.json")
+WAN_VAE_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan_vae_keymap.json")
 
 QWEN_IMAGE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen2_5_vl_config.json")
 QWEN_IMAGE_VISION_CONFIG_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen2_5_vl_vision_config.json")

diffsynth_engine/utils/flag.py

@@ -44,3 +44,9 @@ if SPARGE_ATTN_AVAILABLE:
     logger.info("Sparge attention is available")
 else:
     logger.info("Sparge attention is not available")
+
+VIDEO_SPARSE_ATTN_AVAILABLE = importlib.util.find_spec("vsa") is not None
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    logger.info("Video sparse attention is available")
+else:
+    logger.info("Video sparse attention is not available")
diffsynth_engine/utils/parallel.py

@@ -40,10 +40,14 @@ class ProcessGroupSingleton(Singleton):
     def __init__(self):
         self.CFG_GROUP: Optional[dist.ProcessGroup] = None
         self.SP_GROUP: Optional[dist.ProcessGroup] = None
+        self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
+        self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
         self.TP_GROUP: Optional[dist.ProcessGroup] = None
 
         self.CFG_RANKS: List[int] = []
         self.SP_RANKS: List[int] = []
+        self.SP_ULYSSUES_RANKS: List[int] = []
+        self.SP_RING_RANKS: List[int] = []
         self.TP_RANKS: List[int] = []
 
 
@@ -82,6 +86,38 @@ def get_sp_ranks():
     return PROCESS_GROUP.SP_RANKS
 
 
+def get_sp_ulysses_group():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP
+
+
+def get_sp_ulysses_world_size():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
+
+
+def get_sp_ulysses_rank():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
+
+
+def get_sp_ulysses_ranks():
+    return PROCESS_GROUP.SP_ULYSSUES_RANKS
+
+
+def get_sp_ring_group():
+    return PROCESS_GROUP.SP_RING_GROUP
+
+
+def get_sp_ring_world_size():
+    return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
+
+
+def get_sp_ring_rank():
+    return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
+
+
+def get_sp_ring_ranks():
+    return PROCESS_GROUP.SP_RING_RANKS
+
+
 def get_tp_group():
     return PROCESS_GROUP.TP_GROUP
 
@@ -127,23 +163,32 @@ def init_parallel_pgs(
     blocks = [list(range(world_size))]
     cfg_groups, cfg_blocks = make_parallel_groups(blocks, cfg_degree)
     for cfg_ranks in cfg_groups:
-        cfg_group = dist.new_group(cfg_ranks)
         if rank in cfg_ranks:
-            PROCESS_GROUP.CFG_GROUP = cfg_group
+            PROCESS_GROUP.CFG_GROUP = dist.new_group(cfg_ranks)
             PROCESS_GROUP.CFG_RANKS = cfg_ranks
 
     sp_groups, sp_blocks = make_parallel_groups(cfg_blocks, sp_degree)
     for sp_ranks in sp_groups:
-        group = dist.new_group(sp_ranks)
         if rank in sp_ranks:
-            PROCESS_GROUP.SP_GROUP = group
+            PROCESS_GROUP.SP_GROUP = dist.new_group(sp_ranks)
            PROCESS_GROUP.SP_RANKS = sp_ranks
 
+    sp_ulysses_groups, sp_ulysses_blocks = make_parallel_groups(cfg_blocks, sp_ulysses_degree)
+    for sp_ulysses_ranks in sp_ulysses_groups:
+        if rank in sp_ulysses_ranks:
+            PROCESS_GROUP.SP_ULYSSUES_GROUP = dist.new_group(sp_ulysses_ranks)
+            PROCESS_GROUP.SP_ULYSSUES_RANKS = sp_ulysses_ranks
+
+    sp_ring_groups, _ = make_parallel_groups(sp_ulysses_blocks, sp_ring_degree)
+    for sp_ring_ranks in sp_ring_groups:
+        if rank in sp_ring_ranks:
+            PROCESS_GROUP.SP_RING_GROUP = dist.new_group(sp_ring_ranks)
+            PROCESS_GROUP.SP_RING_RANKS = sp_ring_ranks
+
     tp_groups, _ = make_parallel_groups(sp_blocks, tp_degree)
     for tp_ranks in tp_groups:
-        group = dist.new_group(tp_ranks)
         if rank in tp_ranks:
-            PROCESS_GROUP.TP_GROUP = group
+            PROCESS_GROUP.TP_GROUP = dist.new_group(tp_ranks)
             PROCESS_GROUP.TP_RANKS = tp_ranks
 
     set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
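
Note: these groups refine the flat sequence-parallel group into a 2-D Ulysses x Ring mesh: head-parallel all-to-all runs inside a Ulysses group, while KV blocks circulate between Ulysses groups along a ring. Assuming make_parallel_groups packs Ulysses ranks contiguously (the usual convention, and consistent with sp_ring_groups being built from sp_ulysses_blocks), an 8-rank example:

    # Worked example: world_size=8, cfg_degree=1, sp_ulysses_degree=4,
    # sp_ring_degree=2 (so sp_degree=8). Contiguous packing is assumed.
    world = list(range(8))
    ulysses_degree = 4
    ulysses_groups = [world[i:i + ulysses_degree] for i in range(0, len(world), ulysses_degree)]
    ring_groups = [world[i::ulysses_degree] for i in range(ulysses_degree)]
    assert ulysses_groups == [[0, 1, 2, 3], [4, 5, 6, 7]]   # all-to-all over heads
    assert ring_groups == [[0, 4], [1, 5], [2, 6], [3, 7]]  # ring P2P of KV blocks
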
{diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: diffsynth_engine
-Version: 0.6.1.dev22
+Version: 0.6.1.dev23
 Author: MuseAI x ModelScope
 Classifier: Programming Language :: Python :: 3
 Classifier: Operating System :: OS Independent