diffsynth-engine 0.6.1.dev22__py3-none-any.whl → 0.6.1.dev23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
- diffsynth_engine/configs/pipeline.py +33 -5
- diffsynth_engine/models/basic/attention.py +59 -20
- diffsynth_engine/models/basic/video_sparse_attention.py +235 -0
- diffsynth_engine/models/flux/flux_controlnet.py +7 -19
- diffsynth_engine/models/flux/flux_dit.py +22 -36
- diffsynth_engine/models/flux/flux_dit_fbcache.py +9 -7
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +13 -15
- diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
- diffsynth_engine/models/wan/wan_dit.py +62 -22
- diffsynth_engine/pipelines/flux_image.py +11 -10
- diffsynth_engine/pipelines/qwen_image.py +3 -10
- diffsynth_engine/pipelines/wan_s2v.py +3 -8
- diffsynth_engine/pipelines/wan_video.py +11 -13
- diffsynth_engine/utils/constants.py +13 -12
- diffsynth_engine/utils/flag.py +6 -0
- diffsynth_engine/utils/parallel.py +51 -6
- {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/RECORD +34 -32
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
- {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/top_level.txt +0 -0
diffsynth_engine/models/wan/wan_dit.py
CHANGED

@@ -17,6 +17,7 @@ from diffsynth_engine.utils.constants import (
     WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
     WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
     WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
+    WAN_DIT_KEYMAP_FILE,
 )
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
@@ -30,6 +31,9 @@ from diffsynth_engine.utils.parallel import (
 T5_TOKEN_NUM = 512
 FLF_TOKEN_NUM = 257 * 2
 
+with open(WAN_DIT_KEYMAP_FILE, "r", encoding="utf-8") as f:
+    config = json.load(f)
+
 
 def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
     return x * (1 + scale) + shift
@@ -73,7 +77,7 @@ class SelfAttention(nn.Module):
         dim: int,
         num_heads: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -86,19 +90,25 @@ class SelfAttention(nn.Module):
         self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
         self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
         self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        self.gate_compress = nn.Linear(dim, dim, device=device, dtype=dtype) if use_vsa else None
 
-    def forward(self, x, freqs):
+    def forward(self, x, freqs, attn_kwargs=None):
         q, k, v = self.norm_q(self.q(x)), self.norm_k(self.k(x)), self.v(x)
+        g = self.gate_compress(x) if self.gate_compress is not None else None
+
         num_heads = q.shape[2] // self.head_dim
         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
+        g = rearrange(g, "b s (n d) -> b s n d", n=num_heads) if g is not None else None
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
         x = attention_ops.attention(
             q=rope_apply(q, freqs),
             k=rope_apply(k, freqs),
             v=v,
-            **self.attn_kwargs,
+            g=g,
+            **attn_kwargs,
         )
         x = x.flatten(2)
         return self.o(x)
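The notable addition here is the optional `gate_compress` projection: when `use_vsa` is set, self-attention computes a per-head gate tensor `g` alongside q/k/v and forwards it to `attention_ops.attention`. The actual kernel lives in the new `diffsynth_engine/models/basic/video_sparse_attention.py`, which is not shown in this extract; the sketch below only illustrates the shapes and the gating idea, with a sigmoid-gated SDPA standing in for the real VSA kernel.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def gated_attention(q, k, v, g=None):
    # q, k, v, g: (batch, seq, heads, head_dim); sigmoid-gating the output is
    # only a shape-accurate placeholder for the real VSA kernel, which uses g
    # to mix coarse (block-compressed) and fine (sparse) attention branches.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)
    if g is not None:
        out = out * torch.sigmoid(g)
    return out

b, s, n, d = 1, 16, 2, 8
x = torch.randn(b, s, n * d)
gate_compress = torch.nn.Linear(n * d, n * d)  # mirrors self.gate_compress
g = rearrange(gate_compress(x), "b s (n d) -> b s n d", n=n)
q = k = v = rearrange(x, "b s (n d) -> b s n d", n=n)
print(gated_attention(q, k, v, g).shape)  # torch.Size([1, 16, 2, 8])
```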
@@ -111,7 +121,6 @@ class CrossAttention(nn.Module):
         num_heads: int,
         eps: float = 1e-6,
         has_image_input: bool = False,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -130,9 +139,8 @@ class CrossAttention(nn.Module):
             self.k_img = nn.Linear(dim, dim, device=device, dtype=dtype)
             self.v_img = nn.Linear(dim, dim, device=device, dtype=dtype)
             self.norm_k_img = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
 
-    def forward(self, x: torch.Tensor, y: torch.Tensor):
+    def forward(self, x: torch.Tensor, y: torch.Tensor, attn_kwargs=None):
         if self.has_image_input:
             img = y[:, :-T5_TOKEN_NUM]
             ctx = y[:, -T5_TOKEN_NUM:]
@@ -144,12 +152,16 @@ class CrossAttention(nn.Module):
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
 
-        x = attention(q, k, v, **self.attn_kwargs).flatten(2)
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        if attn_kwargs.get("attn_impl", None) == "vsa":
+            attn_kwargs = attn_kwargs.copy()
+            attn_kwargs["attn_impl"] = "sdpa"
+        x = attention(q, k, v, **attn_kwargs).flatten(2)
         if self.has_image_input:
             k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
             k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
             v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
-            y = attention(q, k_img, v_img, **self.attn_kwargs).flatten(2)
+            y = attention(q, k_img, v_img, **attn_kwargs).flatten(2)
             x = x + y
         return self.o(x)
 
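Cross-attention attends text (and optional image) tokens, which do not form the 3D video grid that video sparse attention needs, so a `vsa` request is downgraded to `sdpa` for this call only. Copying the dict before mutating it keeps the caller's `attn_kwargs` intact for the self-attention path; a minimal sketch of the same pattern:

```python
# Per-call override pattern from the hunk above: downgrade "vsa" to "sdpa"
# for cross-attention without mutating the dict the caller still needs.
def resolve_cross_attn_kwargs(attn_kwargs=None):
    attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
    if attn_kwargs.get("attn_impl", None) == "vsa":
        attn_kwargs = attn_kwargs.copy()  # self-attention still sees "vsa"
        attn_kwargs["attn_impl"] = "sdpa"
    return attn_kwargs

print(resolve_cross_attn_kwargs({"attn_impl": "vsa"}))  # {'attn_impl': 'sdpa'}
```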
@@ -162,7 +174,7 @@ class DiTBlock(nn.Module):
         num_heads: int,
         ffn_dim: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -170,9 +182,9 @@ class DiTBlock(nn.Module):
         self.dim = dim
         self.num_heads = num_heads
         self.ffn_dim = ffn_dim
-        self.self_attn = SelfAttention(dim, num_heads, eps, attn_kwargs, device=device, dtype=dtype)
+        self.self_attn = SelfAttention(dim, num_heads, eps, use_vsa=use_vsa, device=device, dtype=dtype)
         self.cross_attn = CrossAttention(
-            dim, num_heads, eps, has_image_input=has_image_input, attn_kwargs=attn_kwargs, device=device, dtype=dtype
+            dim, num_heads, eps, has_image_input=has_image_input, device=device, dtype=dtype
         )
         self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
         self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
@@ -184,14 +196,14 @@ class DiTBlock(nn.Module):
         )
         self.modulation = nn.Parameter(torch.randn(1, 6, dim, device=device, dtype=dtype) / dim**0.5)
 
-    def forward(self, x, context, t_mod, freqs):
+    def forward(self, x, context, t_mod, freqs, attn_kwargs=None):
         # msa: multi-head self-attention mlp: multi-layer perceptron
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
             t.squeeze(1) for t in (self.modulation + t_mod).chunk(6, dim=1)
         ]
         input_x = modulate(self.norm1(x), shift_msa, scale_msa)
-        x = x + gate_msa * self.self_attn(input_x, freqs)
-        x = x + self.cross_attn(self.norm3(x), context)
+        x = x + gate_msa * self.self_attn(input_x, freqs, attn_kwargs)
+        x = x + self.cross_attn(self.norm3(x), context, attn_kwargs)
         input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
         x = x + gate_mlp * self.ffn(input_x)
         return x
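`DiTBlock.forward` now threads `attn_kwargs` through both attention calls; the modulation logic is unchanged. For readers new to this block, the shape sketch below (assumed shapes, illustrative only) shows how the learned `(1, 6, dim)` table plus the timestep embedding splits into shift/scale/gate triples for the attention and MLP sub-blocks.

```python
import torch

dim = 8
modulation = torch.randn(1, 6, dim) / dim**0.5  # learned table, as above
t_mod = torch.randn(2, 6, dim)  # per-timestep embedding, (batch, 6, dim) assumed
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
    t.squeeze(1) for t in (modulation + t_mod).chunk(6, dim=1)
]
x = torch.randn(2, 16, dim)  # (batch, seq, dim)
# modulate(x, shift, scale) == x * (1 + scale) + shift, broadcast over seq
out = x * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
print(out.shape)  # torch.Size([2, 16, 8])
```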
@@ -249,7 +261,26 @@ class Head(nn.Module):
 
 
 class WanDiTStateDictConverter(StateDictConverter):
+    def _from_diffusers(self, state_dict):
+        global_rename_dict = config["diffusers"]["global_rename_dict"]
+        rename_dict = config["diffusers"]["rename_dict"]
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            suffix = ""
+            suffix = ".weight" if name.endswith(".weight") else suffix
+            suffix = ".bias" if name.endswith(".bias") else suffix
+            prefix = name[: -len(suffix)] if suffix else name
+            if prefix in global_rename_dict:
+                state_dict_[f"{global_rename_dict[prefix]}{suffix}"] = param
+            if prefix.startswith("blocks."):
+                _, idx, middle = prefix.split(".", 2)
+                if middle in rename_dict:
+                    state_dict_[f"blocks.{idx}.{rename_dict[middle]}{suffix}"] = param
+        return state_dict_
+
     def convert(self, state_dict):
+        if "condition_embedder.time_proj.weight" in state_dict:
+            return self._from_diffusers(state_dict)
         return state_dict
 
 
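`convert` now detects Diffusers-format checkpoints by the `condition_embedder.time_proj.weight` key and renames them via the tables in the new `wan_dit_keymap.json` (loaded into the module-level `config` earlier in this diff). The snippet below mirrors that renaming walk against an illustrative keymap; the entry values are invented for the example, not the shipped file's contents.

```python
import json

# Illustrative keymap in the structure _from_diffusers expects.
config = json.loads("""
{
  "diffusers": {
    "global_rename_dict": {"condition_embedder.time_proj": "time_embedding.0"},
    "rename_dict": {"attn1.norm_q": "self_attn.norm_q"}
  }
}
""")

def rename(name):
    # Same suffix-splitting walk as _from_diffusers above.
    suffix = ".weight" if name.endswith(".weight") else (".bias" if name.endswith(".bias") else "")
    prefix = name[: -len(suffix)] if suffix else name
    if prefix in config["diffusers"]["global_rename_dict"]:
        return config["diffusers"]["global_rename_dict"][prefix] + suffix
    if prefix.startswith("blocks."):
        _, idx, middle = prefix.split(".", 2)
        if middle in config["diffusers"]["rename_dict"]:
            return f'blocks.{idx}.{config["diffusers"]["rename_dict"][middle]}{suffix}'
    return None

print(rename("blocks.0.attn1.norm_q.weight"))        # blocks.0.self_attn.norm_q.weight
print(rename("condition_embedder.time_proj.weight"))  # time_embedding.0.weight
```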
@@ -273,7 +304,7 @@ class WanDiT(PreTrainedModel):
         has_vae_feature: bool = False,
         fuse_image_latents: bool = False,
         flf_pos_emb: bool = False,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -307,7 +338,16 @@ class WanDiT(PreTrainedModel):
         )
         self.blocks = nn.ModuleList(
             [
-                DiTBlock(has_clip_feature, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
+                DiTBlock(
+                    has_clip_feature,
+                    dim,
+                    num_heads,
+                    ffn_dim,
+                    eps,
+                    use_vsa,
+                    device=device,
+                    dtype=dtype,
+                )
                 for _ in range(num_layers)
             ]
         )
@@ -344,6 +384,7 @@ class WanDiT(PreTrainedModel):
         timestep: torch.Tensor,
         clip_feature: Optional[torch.Tensor] = None,  # clip_vision_encoder(img)
         y: Optional[torch.Tensor] = None,  # vae_encoder(img)
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         use_cfg = x.shape[0] > 1
@@ -376,7 +417,7 @@ class WanDiT(PreTrainedModel):
 
         with sequence_parallel((x, t, t_mod, freqs), seq_dims=(1, 0, 0, 0)):
             for block in self.blocks:
-                x = block(x, context, t_mod, freqs)
+                x = block(x, context, t_mod, freqs, attn_kwargs)
             x = self.head(x, t)
         (x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(f * h * w,))
         x = self.unpatchify(x, (f, h, w))
@@ -409,12 +450,11 @@ class WanDiT(PreTrainedModel):
         config: Dict[str, Any],
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
-        assign: bool = True,
+        use_vsa: bool = False,
     ):
-        model = cls(**config, device="meta", dtype=dtype, attn_kwargs=attn_kwargs)
+        model = cls(**config, device="meta", dtype=dtype, use_vsa=use_vsa)
         model = model.requires_grad_(False)
-        model.load_state_dict(state_dict, assign=assign)
+        model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
         return model
 
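`from_state_dict` drops the `attn_kwargs` and `assign` parameters: attention options are now supplied per forward call, and `assign=True` is always used with the meta-device construction. A minimal, runnable sketch of that loading pattern, with a generic `nn.Linear` standing in for the WanDiT class:

```python
import torch
import torch.nn as nn

# Construct on the meta device (no real allocations), then let
# load_state_dict adopt the checkpoint tensors directly via assign=True.
with torch.device("meta"):
    model = nn.Linear(4, 4)
model = model.requires_grad_(False)
state_dict = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}
model.load_state_dict(state_dict, assign=True)  # replaces the meta params
model.to(device="cpu", dtype=torch.bfloat16, non_blocking=True)
print(model.weight.device, model.weight.dtype)  # cpu torch.bfloat16
```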
diffsynth_engine/pipelines/flux_image.py
CHANGED

@@ -17,7 +17,12 @@ from diffsynth_engine.models.flux import (
     flux_dit_config,
     flux_text_encoder_config,
 )
-from diffsynth_engine.configs import
+from diffsynth_engine.configs import (
+    FluxPipelineConfig,
+    FluxStateDicts,
+    ControlType,
+    ControlNetParams,
+)
 from diffsynth_engine.models.basic.lora import LoRAContext
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
 from diffsynth_engine.pipelines.utils import accumulate, calculate_shift
@@ -507,20 +512,12 @@ class FluxImagePipeline(BasePipeline):
         vae_encoder = FluxVAEEncoder.from_state_dict(state_dicts.vae, device=init_device, dtype=config.vae_dtype)
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = FluxDiTFBCache.from_state_dict(
                     state_dicts.model,
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:
@@ -529,7 +526,6 @@ class FluxImagePipeline(BasePipeline):
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -755,6 +751,7 @@ class FluxImagePipeline(BasePipeline):
         latents = latents.to(self.dtype)
         self.load_models_to_device(["dit"])
 
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
         noise_pred = self.dit(
             hidden_states=latents,
             timestep=timestep,
@@ -766,6 +763,7 @@ class FluxImagePipeline(BasePipeline):
             image_ids=image_ids,
             controlnet_double_block_output=double_block_output,
             controlnet_single_block_output=single_block_output,
+            attn_kwargs=attn_kwargs,
         )
         noise_pred = noise_pred[:, :image_seq_len]
         noise_pred = self.dit.unpatchify(noise_pred, height, width)
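Across all pipelines in this release, the constructor-time `attn_kwargs` dict (visible in the removed lines above) is replaced by a per-denoising-step `self.config.get_attn_kwargs(latents, self.device)` call; the method itself belongs to the `configs/pipeline.py` changes (+33 -5) not shown in this extract. A hypothetical sketch of what such a builder assembles, based only on the fields this diff exposes:

```python
from types import SimpleNamespace

# Hypothetical sketch, not the real implementation in
# diffsynth_engine/configs/pipeline.py, which may differ.
def get_attn_kwargs(config, latents=None, device=None):
    kwargs = {"attn_impl": config.dit_attn_impl.value}
    if kwargs["attn_impl"] == "sparge":
        kwargs.update(
            sparge_smooth_k=config.sparge_smooth_k,
            sparge_cdfthreshd=config.sparge_cdfthreshd,
            sparge_simthreshd1=config.sparge_simthreshd1,
            sparge_pvthreshd=config.sparge_pvthreshd,
        )
    # Building per call (with latents and device at hand) is what lets VSA
    # use runtime shape information a constructor-time dict could not see.
    return kwargs

cfg = SimpleNamespace(dit_attn_impl=SimpleNamespace(value="sdpa"))
print(get_attn_kwargs(cfg))  # {'attn_impl': 'sdpa'}
```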
@@ -887,6 +885,8 @@ class FluxImagePipeline(BasePipeline):
         if self.offload_mode is not None:
             empty_cache()
             param.model.to(self.device)
+
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
         double_block_output, single_block_output = param.model(
             hidden_states=latents,
             control_condition=control_condition,
@@ -897,6 +897,7 @@ class FluxImagePipeline(BasePipeline):
             image_ids=image_ids,
             text_ids=text_ids,
             guidance=guidance,
+            attn_kwargs=attn_kwargs,
         )
         if self.offload_mode is not None:
             param.model.to("cpu")
diffsynth_engine/pipelines/qwen_image.py
CHANGED

@@ -91,7 +91,7 @@ class QwenImageLoRAConverter(LoRAStateDictConverter):
             if "lora_A.weight" in key:
                 lora_a_suffix = "lora_A.weight"
                 lora_b_suffix = "lora_B.weight"
-
+
             if lora_a_suffix is None:
                 continue
 
@@ -253,19 +253,11 @@ class QwenImagePipeline(BasePipeline):
         )
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = QwenImageDiTFBCache.from_state_dict(
                     state_dicts.model,
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:
@@ -273,7 +265,6 @@ class QwenImagePipeline(BasePipeline):
                     state_dicts.model,
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -548,6 +539,7 @@ class QwenImagePipeline(BasePipeline):
         entity_masks: Optional[List[torch.Tensor]] = None,
     ):
         self.load_models_to_device(["dit"])
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
         noise_pred = self.dit(
             image=latents,
             edit=image_latents,
@@ -558,6 +550,7 @@ class QwenImagePipeline(BasePipeline):
             entity_text=entity_prompt_embs,
             entity_seq_lens=[mask.sum(dim=1) for mask in entity_prompt_emb_masks] if entity_prompt_emb_masks else None,
             entity_masks=entity_masks,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred
 
diffsynth_engine/pipelines/wan_s2v.py
CHANGED

@@ -394,6 +394,7 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         void_audio_input: torch.Tensor | None = None,
     ):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
 
         noise_pred = model(
             x=latents,
@@ -408,6 +409,7 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
             drop_motion_frames=drop_motion_frames,
             audio_mask=audio_mask,
             void_audio_input=void_audio_input,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred
 
@@ -654,19 +656,12 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         )
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             dit = WanS2VDiT.from_state_dict(
                 state_dicts.model,
                 config=model_config,
                 device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-                attn_kwargs=attn_kwargs,
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
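Note the asymmetry: sparge thresholds were plain runtime kwargs, but VSA changes the module structure (the `gate_compress` projections), so it must be decided at construction via `use_vsa`, derived from the configured attention implementation. A small sketch of that wiring, where `AttnImpl` is a stand-in for the real config enum in `diffsynth_engine.configs`:

```python
from enum import Enum

class AttnImpl(Enum):  # illustrative stand-in, not the package's enum
    SDPA = "sdpa"
    SPARGE = "sparge"
    VSA = "vsa"

dit_attn_impl = AttnImpl.VSA
use_vsa = dit_attn_impl.value == "vsa"  # decides whether gate_compress exists
print(use_vsa)  # True
```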
diffsynth_engine/pipelines/wan_video.py
CHANGED

@@ -323,6 +323,7 @@ class WanVideoPipeline(BasePipeline):
 
     def predict_noise(self, model, latents, image_clip_feature, image_y, timestep, context):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
 
         noise_pred = model(
             x=latents,
@@ -330,6 +331,7 @@ class WanVideoPipeline(BasePipeline):
             context=context,
             clip_feature=image_clip_feature,
             y=image_y,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred
 
@@ -578,19 +580,12 @@ class WanVideoPipeline(BasePipeline):
         dit_state_dict = state_dicts.model
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             dit = WanDiT.from_state_dict(
                 dit_state_dict,
                 config=dit_config,
                 device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-                attn_kwargs=attn_kwargs,
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
            )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -602,7 +597,7 @@ class WanVideoPipeline(BasePipeline):
                 config=dit_config,
                 device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-                attn_kwargs=attn_kwargs,
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit2)
@@ -640,19 +635,22 @@ class WanVideoPipeline(BasePipeline):
     @staticmethod
     def _get_dit_type(model_state_dict: Dict[str, torch.Tensor] | Dict[str, Dict[str, torch.Tensor]]) -> str:
         # determine wan dit type by model params
+        def has_any_key(*xs):
+            return any(x in model_state_dict for x in xs)
+
         dit_type = None
-        if "high_noise_model" in model_state_dict:
+        if has_any_key("high_noise_model"):
             if model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 36:
                 dit_type = "wan2.2-i2v-a14b"
             elif model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 16:
                 dit_type = "wan2.2-t2v-a14b"
         elif model_state_dict["patch_embedding.weight"].shape[1] == 48:
             dit_type = "wan2.2-ti2v-5b"
-        elif "img_emb.emb_pos" in model_state_dict:
+        elif has_any_key("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed"):
             dit_type = "wan2.1-flf2v-14b"
-        elif "img_emb.proj.0.weight" in model_state_dict:
+        elif has_any_key("img_emb.proj.0.weight", "condition_embedder.image_embedder.norm1"):
             dit_type = "wan2.1-i2v-14b"
-        elif "blocks.39.self_attn.norm_q.weight" in model_state_dict:
+        elif has_any_key("blocks.39.self_attn.norm_q.weight", "blocks.39.attn1.norm_q.weight"):
             dit_type = "wan2.1-t2v-14b"
         else:
             dit_type = "wan2.1-t2v-1.3b"
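`_get_dit_type` now recognizes both native Wan key names and their Diffusers-style aliases through the `has_any_key` helper, which is what lets Diffusers checkpoints flow into the keymap conversion added in `wan_dit.py`. A runnable sketch of the probe:

```python
# Dual-format detection as in the hunk above: native Wan names and
# Diffusers-style aliases are checked with the same helper.
model_state_dict = {
    "patch_embedding.weight": None,
    "blocks.39.attn1.norm_q.weight": None,  # Diffusers-style alias
}

def has_any_key(*xs):
    return any(x in model_state_dict for x in xs)

if has_any_key("blocks.39.self_attn.norm_q.weight", "blocks.39.attn1.norm_q.weight"):
    print("wan2.1-t2v-14b")  # matched via the alias
```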
diffsynth_engine/utils/constants.py
CHANGED

@@ -27,18 +27,19 @@ SD3_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd3", "sd3_tex
 SDXL_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_text_encoder.json")
 SDXL_UNET_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_unet.json")
 
-WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-1.3b.json")
-WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-14b.json")
-WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-i2v-14b.json")
-WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-flf2v-14b.json")
-WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-ti2v-5b.json")
-WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-t2v-a14b.json")
-WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-i2v-a14b.json")
-WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-s2v-14b.json")
-
-WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1-vae.json")
-WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2-vae.json")
-WAN_VAE_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan-vae-keymap.json")
+WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_1.3b.json")
+WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_14b.json")
+WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_i2v_14b.json")
+WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_flf2v_14b.json")
+WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_ti2v_5b.json")
+WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_t2v_a14b.json")
+WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_i2v_a14b.json")
+WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_s2v_14b.json")
+WAN_DIT_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan_dit_keymap.json")
+
+WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1_vae.json")
+WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2_vae.json")
+WAN_VAE_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan_vae_keymap.json")
 
 QWEN_IMAGE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen2_5_vl_config.json")
 QWEN_IMAGE_VISION_CONFIG_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen2_5_vl_vision_config.json")
diffsynth_engine/utils/flag.py
CHANGED
@@ -44,3 +44,9 @@ if SPARGE_ATTN_AVAILABLE:
     logger.info("Sparge attention is available")
 else:
     logger.info("Sparge attention is not available")
+
+VIDEO_SPARSE_ATTN_AVAILABLE = importlib.util.find_spec("vsa") is not None
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    logger.info("Video sparse attention is available")
+else:
+    logger.info("Video sparse attention is not available")
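The flag follows the same pattern as the existing sparge check: `importlib.util.find_spec` probes whether the optional `vsa` kernel package is installed without importing it. One way a caller might consume the flag (the fallback function here is hypothetical, not part of the package):

```python
import importlib.util

# find_spec reports importability without importing, so the probe is cheap
# and side-effect free.
VIDEO_SPARSE_ATTN_AVAILABLE = importlib.util.find_spec("vsa") is not None

def pick_attn_impl(requested: str) -> str:
    # Hypothetical guard: degrade gracefully when the kernels are absent.
    if requested == "vsa" and not VIDEO_SPARSE_ATTN_AVAILABLE:
        return "sdpa"
    return requested

print(pick_attn_impl("vsa"))
```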
diffsynth_engine/utils/parallel.py
CHANGED

@@ -40,10 +40,14 @@ class ProcessGroupSingleton(Singleton):
     def __init__(self):
         self.CFG_GROUP: Optional[dist.ProcessGroup] = None
         self.SP_GROUP: Optional[dist.ProcessGroup] = None
+        self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
+        self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
         self.TP_GROUP: Optional[dist.ProcessGroup] = None
 
         self.CFG_RANKS: List[int] = []
         self.SP_RANKS: List[int] = []
+        self.SP_ULYSSUES_RANKS: List[int] = []
+        self.SP_RING_RANKS: List[int] = []
         self.TP_RANKS: List[int] = []
 
 
@@ -82,6 +86,38 @@ def get_sp_ranks():
     return PROCESS_GROUP.SP_RANKS
 
 
+def get_sp_ulysses_group():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP
+
+
+def get_sp_ulysses_world_size():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
+
+
+def get_sp_ulysses_rank():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
+
+
+def get_sp_ulysses_ranks():
+    return PROCESS_GROUP.SP_ULYSSUES_RANKS
+
+
+def get_sp_ring_group():
+    return PROCESS_GROUP.SP_RING_GROUP
+
+
+def get_sp_ring_world_size():
+    return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
+
+
+def get_sp_ring_rank():
+    return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
+
+
+def get_sp_ring_ranks():
+    return PROCESS_GROUP.SP_RING_RANKS
+
+
 def get_tp_group():
     return PROCESS_GROUP.TP_GROUP
 
@@ -127,23 +163,32 @@ def init_parallel_pgs(
     blocks = [list(range(world_size))]
     cfg_groups, cfg_blocks = make_parallel_groups(blocks, cfg_degree)
     for cfg_ranks in cfg_groups:
-        cfg_group = dist.new_group(cfg_ranks)
         if rank in cfg_ranks:
-            PROCESS_GROUP.CFG_GROUP = cfg_group
+            PROCESS_GROUP.CFG_GROUP = dist.new_group(cfg_ranks)
             PROCESS_GROUP.CFG_RANKS = cfg_ranks
 
     sp_groups, sp_blocks = make_parallel_groups(cfg_blocks, sp_degree)
     for sp_ranks in sp_groups:
-        group = dist.new_group(sp_ranks)
         if rank in sp_ranks:
-            PROCESS_GROUP.SP_GROUP = group
+            PROCESS_GROUP.SP_GROUP = dist.new_group(sp_ranks)
             PROCESS_GROUP.SP_RANKS = sp_ranks
 
+    sp_ulysses_groups, sp_ulysses_blocks = make_parallel_groups(cfg_blocks, sp_ulysses_degree)
+    for sp_ulysses_ranks in sp_ulysses_groups:
+        if rank in sp_ulysses_ranks:
+            PROCESS_GROUP.SP_ULYSSUES_GROUP = dist.new_group(sp_ulysses_ranks)
+            PROCESS_GROUP.SP_ULYSSUES_RANKS = sp_ulysses_ranks
+
+    sp_ring_groups, _ = make_parallel_groups(sp_ulysses_blocks, sp_ring_degree)
+    for sp_ring_ranks in sp_ring_groups:
+        if rank in sp_ring_ranks:
+            PROCESS_GROUP.SP_RING_GROUP = dist.new_group(sp_ring_ranks)
+            PROCESS_GROUP.SP_RING_RANKS = sp_ring_ranks
+
     tp_groups, _ = make_parallel_groups(sp_blocks, tp_degree)
     for tp_ranks in tp_groups:
-        group = dist.new_group(tp_ranks)
         if rank in tp_ranks:
-            PROCESS_GROUP.TP_GROUP = group
+            PROCESS_GROUP.TP_GROUP = dist.new_group(tp_ranks)
             PROCESS_GROUP.TP_RANKS = tp_ranks
 
     set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)