diffsynth-engine 0.6.1.dev22__py3-none-any.whl → 0.6.1.dev24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
- diffsynth_engine/configs/pipeline.py +35 -12
- diffsynth_engine/models/basic/attention.py +59 -20
- diffsynth_engine/models/basic/transformer_helper.py +36 -2
- diffsynth_engine/models/basic/video_sparse_attention.py +235 -0
- diffsynth_engine/models/flux/flux_controlnet.py +7 -19
- diffsynth_engine/models/flux/flux_dit.py +22 -36
- diffsynth_engine/models/flux/flux_dit_fbcache.py +9 -7
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +26 -32
- diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
- diffsynth_engine/models/wan/wan_dit.py +62 -22
- diffsynth_engine/pipelines/flux_image.py +11 -10
- diffsynth_engine/pipelines/qwen_image.py +16 -15
- diffsynth_engine/pipelines/utils.py +52 -0
- diffsynth_engine/pipelines/wan_s2v.py +3 -8
- diffsynth_engine/pipelines/wan_video.py +11 -13
- diffsynth_engine/tokenizers/base.py +6 -0
- diffsynth_engine/tokenizers/qwen2.py +12 -4
- diffsynth_engine/utils/constants.py +13 -12
- diffsynth_engine/utils/flag.py +6 -0
- diffsynth_engine/utils/parallel.py +51 -6
- {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev24.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev24.dist-info}/RECORD +38 -36
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
- {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev24.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev24.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev24.dist-info}/top_level.txt +0 -0
diffsynth_engine/models/wan/wan_dit.py
CHANGED

@@ -17,6 +17,7 @@ from diffsynth_engine.utils.constants import (
     WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
     WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
     WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
+    WAN_DIT_KEYMAP_FILE,
 )
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference

@@ -30,6 +31,9 @@ from diffsynth_engine.utils.parallel import (
 T5_TOKEN_NUM = 512
 FLF_TOKEN_NUM = 257 * 2
 
+with open(WAN_DIT_KEYMAP_FILE, "r", encoding="utf-8") as f:
+    config = json.load(f)
+
 
 def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
     return x * (1 + scale) + shift

@@ -73,7 +77,7 @@ class SelfAttention(nn.Module):
         dim: int,
         num_heads: int,
         eps: float = 1e-6,
-
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):

@@ -86,19 +90,25 @@ class SelfAttention(nn.Module):
         self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
         self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
         self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.
+        self.gate_compress = nn.Linear(dim, dim, device=device, dtype=dtype) if use_vsa else None
 
-    def forward(self, x, freqs):
+    def forward(self, x, freqs, attn_kwargs=None):
         q, k, v = self.norm_q(self.q(x)), self.norm_k(self.k(x)), self.v(x)
+        g = self.gate_compress(x) if self.gate_compress is not None else None
+
         num_heads = q.shape[2] // self.head_dim
         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
+        g = rearrange(g, "b s (n d) -> b s n d", n=num_heads) if g is not None else None
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
         x = attention_ops.attention(
             q=rope_apply(q, freqs),
             k=rope_apply(k, freqs),
             v=v,
-
+            g=g,
+            **attn_kwargs,
         )
         x = x.flatten(2)
         return self.o(x)

@@ -111,7 +121,6 @@ class CrossAttention(nn.Module):
         num_heads: int,
         eps: float = 1e-6,
         has_image_input: bool = False,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):

@@ -130,9 +139,8 @@ class CrossAttention(nn.Module):
             self.k_img = nn.Linear(dim, dim, device=device, dtype=dtype)
             self.v_img = nn.Linear(dim, dim, device=device, dtype=dtype)
             self.norm_k_img = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
 
-    def forward(self, x: torch.Tensor, y: torch.Tensor):
+    def forward(self, x: torch.Tensor, y: torch.Tensor, attn_kwargs=None):
         if self.has_image_input:
             img = y[:, :-T5_TOKEN_NUM]
             ctx = y[:, -T5_TOKEN_NUM:]

@@ -144,12 +152,16 @@ class CrossAttention(nn.Module):
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
 
-
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        if attn_kwargs.get("attn_impl", None) == "vsa":
+            attn_kwargs = attn_kwargs.copy()
+            attn_kwargs["attn_impl"] = "sdpa"
+        x = attention(q, k, v, **attn_kwargs).flatten(2)
         if self.has_image_input:
             k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
             k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
             v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
-            y = attention(q, k_img, v_img, **
+            y = attention(q, k_img, v_img, **attn_kwargs).flatten(2)
             x = x + y
         return self.o(x)
 

@@ -162,7 +174,7 @@ class DiTBlock(nn.Module):
         num_heads: int,
         ffn_dim: int,
         eps: float = 1e-6,
-
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):

@@ -170,9 +182,9 @@ class DiTBlock(nn.Module):
         self.dim = dim
         self.num_heads = num_heads
         self.ffn_dim = ffn_dim
-        self.self_attn = SelfAttention(dim, num_heads, eps,
+        self.self_attn = SelfAttention(dim, num_heads, eps, use_vsa=use_vsa, device=device, dtype=dtype)
         self.cross_attn = CrossAttention(
-            dim, num_heads, eps, has_image_input=has_image_input,
+            dim, num_heads, eps, has_image_input=has_image_input, device=device, dtype=dtype
         )
         self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
         self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)

@@ -184,14 +196,14 @@ class DiTBlock(nn.Module):
         )
         self.modulation = nn.Parameter(torch.randn(1, 6, dim, device=device, dtype=dtype) / dim**0.5)
 
-    def forward(self, x, context, t_mod, freqs):
+    def forward(self, x, context, t_mod, freqs, attn_kwargs=None):
         # msa: multi-head self-attention mlp: multi-layer perceptron
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
             t.squeeze(1) for t in (self.modulation + t_mod).chunk(6, dim=1)
         ]
         input_x = modulate(self.norm1(x), shift_msa, scale_msa)
-        x = x + gate_msa * self.self_attn(input_x, freqs)
-        x = x + self.cross_attn(self.norm3(x), context)
+        x = x + gate_msa * self.self_attn(input_x, freqs, attn_kwargs)
+        x = x + self.cross_attn(self.norm3(x), context, attn_kwargs)
         input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
         x = x + gate_mlp * self.ffn(input_x)
         return x
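The wan_dit.py hunks above replace the constructor-time attention settings with a per-call `attn_kwargs` dict and add an optional `gate_compress` projection (only created when `use_vsa=True`) whose output `g` is forwarded to `attention_ops.attention`. Cross-attention explicitly downgrades a `"vsa"` request to `"sdpa"`. A minimal standalone sketch of that dispatch pattern follows; it uses plain PyTorch SDPA rather than the library's attention ops, and the dispatcher itself is illustrative only.

```python
import torch
import torch.nn.functional as F


def attention(q, k, v, attn_impl="sdpa", **kwargs):
    # Illustrative dispatcher only: the real diffsynth_engine attention() accepts
    # additional implementations (e.g. the sparge and vsa kernels referenced
    # elsewhere in this diff); here only the SDPA path is implemented.
    if attn_impl == "sdpa":
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        )
        return out.transpose(1, 2)
    raise NotImplementedError(attn_impl)


def cross_attention_call(q, k, v, attn_kwargs=None):
    # Mirrors the CrossAttention.forward change: a "vsa" request is copied and
    # downgraded to "sdpa" before the call, so the sparse kernel is only used
    # on the self-attention path.
    attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
    if attn_kwargs.get("attn_impl", None) == "vsa":
        attn_kwargs = attn_kwargs.copy()
        attn_kwargs["attn_impl"] = "sdpa"
    return attention(q, k, v, **attn_kwargs)


q = k = v = torch.randn(1, 16, 8, 64)  # (batch, seq, heads, head_dim)
out = cross_attention_call(q, k, v, {"attn_impl": "vsa"})
print(out.shape)  # torch.Size([1, 16, 8, 64])
```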
@@ -249,7 +261,26 @@ class Head(nn.Module):
 
 
 class WanDiTStateDictConverter(StateDictConverter):
+    def _from_diffusers(self, state_dict):
+        global_rename_dict = config["diffusers"]["global_rename_dict"]
+        rename_dict = config["diffusers"]["rename_dict"]
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            suffix = ""
+            suffix = ".weight" if name.endswith(".weight") else suffix
+            suffix = ".bias" if name.endswith(".bias") else suffix
+            prefix = name[: -len(suffix)] if suffix else name
+            if prefix in global_rename_dict:
+                state_dict_[f"{global_rename_dict[prefix]}{suffix}"] = param
+            if prefix.startswith("blocks."):
+                _, idx, middle = prefix.split(".", 2)
+                if middle in rename_dict:
+                    state_dict_[f"blocks.{idx}.{rename_dict[middle]}{suffix}"] = param
+        return state_dict_
+
     def convert(self, state_dict):
+        if "condition_embedder.time_proj.weight" in state_dict:
+            return self._from_diffusers(state_dict)
         return state_dict
 
 

@@ -273,7 +304,7 @@ class WanDiT(PreTrainedModel):
         has_vae_feature: bool = False,
         fuse_image_latents: bool = False,
         flf_pos_emb: bool = False,
-
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):

@@ -307,7 +338,16 @@ class WanDiT(PreTrainedModel):
         )
         self.blocks = nn.ModuleList(
             [
-                DiTBlock(
+                DiTBlock(
+                    has_clip_feature,
+                    dim,
+                    num_heads,
+                    ffn_dim,
+                    eps,
+                    use_vsa,
+                    device=device,
+                    dtype=dtype,
+                )
                 for _ in range(num_layers)
             ]
         )

@@ -344,6 +384,7 @@ class WanDiT(PreTrainedModel):
         timestep: torch.Tensor,
         clip_feature: Optional[torch.Tensor] = None,  # clip_vision_encoder(img)
         y: Optional[torch.Tensor] = None,  # vae_encoder(img)
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         use_cfg = x.shape[0] > 1

@@ -376,7 +417,7 @@ class WanDiT(PreTrainedModel):
 
         with sequence_parallel((x, t, t_mod, freqs), seq_dims=(1, 0, 0, 0)):
             for block in self.blocks:
-                x = block(x, context, t_mod, freqs)
+                x = block(x, context, t_mod, freqs, attn_kwargs)
             x = self.head(x, t)
             (x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(f * h * w,))
         x = self.unpatchify(x, (f, h, w))

@@ -409,12 +450,11 @@ class WanDiT(PreTrainedModel):
         config: Dict[str, Any],
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
-
-        assign: bool = True,
+        use_vsa: bool = False,
     ):
-        model = cls(**config, device="meta", dtype=dtype,
+        model = cls(**config, device="meta", dtype=dtype, use_vsa=use_vsa)
         model = model.requires_grad_(False)
-        model.load_state_dict(state_dict, assign=
+        model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
         return model
 
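`WanDiTStateDictConverter.convert` now detects Diffusers-format checkpoints by the presence of `condition_embedder.time_proj.weight` and routes them through `_from_diffusers`, which is driven by the new `wan_dit_keymap.json` (its contents are not shown in this diff). A toy rerun of the renaming logic with a made-up keymap, to make the prefix/suffix handling concrete:

```python
# Hypothetical keymap for illustration; the real mapping ships in
# diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json (not shown here).
config = {
    "diffusers": {
        "global_rename_dict": {"condition_embedder.time_proj": "time_projection"},
        "rename_dict": {"attn1.norm_q": "self_attn.norm_q"},
    }
}


def from_diffusers(state_dict):
    # Same logic as WanDiTStateDictConverter._from_diffusers in the hunk above.
    global_rename_dict = config["diffusers"]["global_rename_dict"]
    rename_dict = config["diffusers"]["rename_dict"]
    out = {}
    for name, param in state_dict.items():
        suffix = ".weight" if name.endswith(".weight") else (".bias" if name.endswith(".bias") else "")
        prefix = name[: -len(suffix)] if suffix else name
        if prefix in global_rename_dict:
            out[f"{global_rename_dict[prefix]}{suffix}"] = param
        if prefix.startswith("blocks."):
            _, idx, middle = prefix.split(".", 2)
            if middle in rename_dict:
                out[f"blocks.{idx}.{rename_dict[middle]}{suffix}"] = param
    return out


print(from_diffusers({"blocks.39.attn1.norm_q.weight": "w"}))
# {'blocks.39.self_attn.norm_q.weight': 'w'}  -- keyed per block, suffix preserved
```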
diffsynth_engine/pipelines/flux_image.py
CHANGED

@@ -17,7 +17,12 @@ from diffsynth_engine.models.flux import (
     flux_dit_config,
     flux_text_encoder_config,
 )
-from diffsynth_engine.configs import
+from diffsynth_engine.configs import (
+    FluxPipelineConfig,
+    FluxStateDicts,
+    ControlType,
+    ControlNetParams,
+)
 from diffsynth_engine.models.basic.lora import LoRAContext
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
 from diffsynth_engine.pipelines.utils import accumulate, calculate_shift

@@ -507,20 +512,12 @@ class FluxImagePipeline(BasePipeline):
         vae_encoder = FluxVAEEncoder.from_state_dict(state_dicts.vae, device=init_device, dtype=config.vae_dtype)
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = FluxDiTFBCache.from_state_dict(
                     state_dicts.model,
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:

@@ -529,7 +526,6 @@ class FluxImagePipeline(BasePipeline):
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)

@@ -755,6 +751,7 @@ class FluxImagePipeline(BasePipeline):
         latents = latents.to(self.dtype)
         self.load_models_to_device(["dit"])
 
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
         noise_pred = self.dit(
             hidden_states=latents,
             timestep=timestep,

@@ -766,6 +763,7 @@ class FluxImagePipeline(BasePipeline):
             image_ids=image_ids,
             controlnet_double_block_output=double_block_output,
             controlnet_single_block_output=single_block_output,
+            attn_kwargs=attn_kwargs,
         )
         noise_pred = noise_pred[:, :image_seq_len]
         noise_pred = self.dit.unpatchify(noise_pred, height, width)

@@ -887,6 +885,8 @@ class FluxImagePipeline(BasePipeline):
             if self.offload_mode is not None:
                 empty_cache()
                 param.model.to(self.device)
+
+            attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
             double_block_output, single_block_output = param.model(
                 hidden_states=latents,
                 control_condition=control_condition,

@@ -897,6 +897,7 @@ class FluxImagePipeline(BasePipeline):
                 image_ids=image_ids,
                 text_ids=text_ids,
                 guidance=guidance,
+                attn_kwargs=attn_kwargs,
             )
             if self.offload_mode is not None:
                 param.model.to("cpu")
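Across the pipelines in this release, the `attn_kwargs` dict is no longer baked into the DiT at construction time; each denoising step asks the pipeline config for it via `self.config.get_attn_kwargs(latents, self.device)`. The method itself lives in `diffsynth_engine/configs/pipeline.py` (changed in this release but not included in the hunks above), so the sketch below is an assumption based only on the keys visible in the removed inline dict:

```python
import torch


class PipelineConfigSketch:
    """Hypothetical stand-in for the real pipeline config class (not the library API)."""

    def __init__(self, attn_impl: str = "sdpa", **sparge_params):
        self.attn_impl = attn_impl
        # e.g. sparge_smooth_k, sparge_cdfthreshd, sparge_simthreshd1, sparge_pvthreshd
        self.sparge_params = sparge_params

    def get_attn_kwargs(self, latents: torch.Tensor, device) -> dict:
        # Same keys as the dict that was removed from the pipelines. Taking the
        # latents and device as arguments suggests the real method can also derive
        # shape- or device-dependent parameters (e.g. for video sparse attention),
        # which a static constructor-time dict could not do.
        return {"attn_impl": self.attn_impl, **self.sparge_params}


cfg = PipelineConfigSketch(attn_impl="sdpa", sparge_smooth_k=1.0)
print(cfg.get_attn_kwargs(torch.zeros(1, 16, 64, 64), "cpu"))
# {'attn_impl': 'sdpa', 'sparge_smooth_k': 1.0}
```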
diffsynth_engine/pipelines/qwen_image.py
CHANGED

@@ -24,7 +24,7 @@ from diffsynth_engine.models.qwen_image import (
 from diffsynth_engine.models.qwen_image import QwenImageVAE
 from diffsynth_engine.tokenizers import Qwen2TokenizerFast, Qwen2VLProcessor
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
-from diffsynth_engine.pipelines.utils import calculate_shift
+from diffsynth_engine.pipelines.utils import calculate_shift, pad_and_concat
 from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
 from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
 from diffsynth_engine.utils.constants import (

@@ -91,7 +91,7 @@ class QwenImageLoRAConverter(LoRAStateDictConverter):
             if "lora_A.weight" in key:
                 lora_a_suffix = "lora_A.weight"
                 lora_b_suffix = "lora_B.weight"
-
+
             if lora_a_suffix is None:
                 continue
 

@@ -148,9 +148,17 @@ class QwenImagePipeline(BasePipeline):
         self.prompt_template_encode_start_idx = 34
         # qwen image edit
         self.edit_system_prompt = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
-        self.edit_prompt_template_encode =
+        self.edit_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+        )
         # qwen image edit plus
-        self.edit_plus_prompt_template_encode =
+        self.edit_plus_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        )
 
         self.edit_prompt_template_encode_start_idx = 64
 

@@ -253,19 +261,11 @@ class QwenImagePipeline(BasePipeline):
         )
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = QwenImageDiTFBCache.from_state_dict(
                     state_dicts.model,
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:

@@ -273,7 +273,6 @@ class QwenImagePipeline(BasePipeline):
                     state_dicts.model,
                     device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)

@@ -499,8 +498,8 @@ class QwenImagePipeline(BasePipeline):
         else:
             # cfg by predict noise in one batch
             bs, _, h, w = latents.shape
-            prompt_emb =
-            prompt_emb_mask =
+            prompt_emb = pad_and_concat(prompt_emb, negative_prompt_emb)
+            prompt_emb_mask = pad_and_concat(prompt_emb_mask, negative_prompt_emb_mask)
             if entity_prompt_embs is not None:
                 entity_prompt_embs = [
                     torch.cat([x, y], dim=0) for x, y in zip(entity_prompt_embs, negative_entity_prompt_embs)

@@ -548,6 +547,7 @@ class QwenImagePipeline(BasePipeline):
         entity_masks: Optional[List[torch.Tensor]] = None,
     ):
         self.load_models_to_device(["dit"])
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
         noise_pred = self.dit(
             image=latents,
             edit=image_latents,

@@ -558,6 +558,7 @@ class QwenImagePipeline(BasePipeline):
             entity_text=entity_prompt_embs,
             entity_seq_lens=[mask.sum(dim=1) for mask in entity_prompt_emb_masks] if entity_prompt_emb_masks else None,
             entity_masks=entity_masks,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred
 
diffsynth_engine/pipelines/utils.py
CHANGED

@@ -1,3 +1,7 @@
+import torch
+import torch.nn.functional as F
+
+
 def accumulate(result, new_item):
     if result is None:
         return new_item

@@ -17,3 +21,51 @@ def calculate_shift(
     b = base_shift - m * base_seq_len
     mu = image_seq_len * m + b
     return mu
+
+
+def pad_and_concat(
+    tensor1: torch.Tensor,
+    tensor2: torch.Tensor,
+    concat_dim: int = 0,
+    pad_dim: int = 1,
+) -> torch.Tensor:
+    """
+    Concatenate two tensors along a specified dimension after padding along another dimension.
+
+    Assumes input tensors have shape (b, s, d), where:
+    - b: batch dimension
+    - s: sequence dimension (may differ)
+    - d: feature dimension
+
+    Args:
+        tensor1: First tensor with shape (b1, s1, d)
+        tensor2: Second tensor with shape (b2, s2, d)
+        concat_dim: Dimension to concatenate along, default is 0 (batch dimension)
+        pad_dim: Dimension to pad along, default is 1 (sequence dimension)
+
+    Returns:
+        Concatenated tensor, shape depends on concat_dim and pad_dim choices
+    """
+    assert tensor1.dim() == tensor2.dim(), "Both tensors must have the same number of dimensions"
+    assert concat_dim != pad_dim, "concat_dim and pad_dim cannot be the same"
+
+    len1, len2 = tensor1.shape[pad_dim], tensor2.shape[pad_dim]
+    max_len = max(len1, len2)
+
+    # Calculate the position of pad_dim in the padding list
+    # Padding format: from the last dimension, each pair represents (dim_n_left, dim_n_right, ..., dim_0_left, dim_0_right)
+    ndim = tensor1.dim()
+    padding = [0] * (2 * ndim)
+    pad_right_idx = -2 * pad_dim - 1
+
+    if len1 < max_len:
+        pad_len = max_len - len1
+        padding[pad_right_idx] = pad_len
+        tensor1 = F.pad(tensor1, padding, mode="constant", value=0)
+    elif len2 < max_len:
+        pad_len = max_len - len2
+        padding[pad_right_idx] = pad_len
+        tensor2 = F.pad(tensor2, padding, mode="constant", value=0)
+
+    # Concatenate along the specified dimension
+    return torch.cat([tensor1, tensor2], dim=concat_dim)
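`pad_and_concat` zero-pads the shorter sequence before concatenating, which is what lets `QwenImagePipeline` batch positive and negative prompt embeddings of different token lengths for single-pass CFG (see the `@@ -499,8 +498,8 @@` hunk above). A quick usage example; the shapes are arbitrary and only illustrate the (b, s, d) convention from the docstring:

```python
import torch
from diffsynth_engine.pipelines.utils import pad_and_concat  # added in 0.6.1.dev24

pos = torch.randn(1, 20, 3584)  # positive prompt embedding, 20 tokens
neg = torch.randn(1, 12, 3584)  # negative prompt embedding, 12 tokens

batched = pad_and_concat(pos, neg)         # pad along dim 1, concat along dim 0
print(batched.shape)                       # torch.Size([2, 20, 3584])
print(batched[1, 12:].abs().sum().item())  # 0.0 -- the shorter sequence is zero-padded
```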
diffsynth_engine/pipelines/wan_s2v.py
CHANGED

@@ -394,6 +394,7 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         void_audio_input: torch.Tensor | None = None,
     ):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
 
         noise_pred = model(
             x=latents,

@@ -408,6 +409,7 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
             drop_motion_frames=drop_motion_frames,
             audio_mask=audio_mask,
             void_audio_input=void_audio_input,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred
 

@@ -654,19 +656,12 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         )
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             dit = WanS2VDiT.from_state_dict(
                 state_dicts.model,
                 config=model_config,
                 device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
diffsynth_engine/pipelines/wan_video.py
CHANGED

@@ -323,6 +323,7 @@ class WanVideoPipeline(BasePipeline):
 
     def predict_noise(self, model, latents, image_clip_feature, image_y, timestep, context):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
+        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
 
         noise_pred = model(
             x=latents,

@@ -330,6 +331,7 @@ class WanVideoPipeline(BasePipeline):
             context=context,
             clip_feature=image_clip_feature,
             y=image_y,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred
 

@@ -578,19 +580,12 @@ class WanVideoPipeline(BasePipeline):
         dit_state_dict = state_dicts.model
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             dit = WanDiT.from_state_dict(
                 dit_state_dict,
                 config=dit_config,
                 device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)

@@ -602,7 +597,7 @@ class WanVideoPipeline(BasePipeline):
                 config=dit_config,
                 device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit2)

@@ -640,19 +635,22 @@ class WanVideoPipeline(BasePipeline):
     @staticmethod
     def _get_dit_type(model_state_dict: Dict[str, torch.Tensor] | Dict[str, Dict[str, torch.Tensor]]) -> str:
         # determine wan dit type by model params
+        def has_any_key(*xs):
+            return any(x in model_state_dict for x in xs)
+
         dit_type = None
-        if "high_noise_model"
+        if has_any_key("high_noise_model"):
             if model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 36:
                 dit_type = "wan2.2-i2v-a14b"
             elif model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 16:
                 dit_type = "wan2.2-t2v-a14b"
         elif model_state_dict["patch_embedding.weight"].shape[1] == 48:
             dit_type = "wan2.2-ti2v-5b"
-        elif "img_emb.emb_pos"
+        elif has_any_key("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed"):
             dit_type = "wan2.1-flf2v-14b"
-        elif "img_emb.proj.0.weight"
+        elif has_any_key("img_emb.proj.0.weight", "condition_embedder.image_embedder.norm1"):
             dit_type = "wan2.1-i2v-14b"
-        elif "blocks.39.self_attn.norm_q.weight"
+        elif has_any_key("blocks.39.self_attn.norm_q.weight", "blocks.39.attn1.norm_q.weight"):
             dit_type = "wan2.1-t2v-14b"
         else:
             dit_type = "wan2.1-t2v-1.3b"
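`_get_dit_type` now probes for either the native key names or their Diffusers-style equivalents when sniffing the checkpoint variant, matching the new `_from_diffusers` converter in wan_dit.py. A condensed, self-contained version of the key-probing idea (the wan2.2 branches that inspect `patch_embedding.weight` shapes are omitted, and the dummy tensors below exist only to supply key names):

```python
import torch


def get_dit_type_sketch(model_state_dict: dict) -> str:
    # Condensed from WanVideoPipeline._get_dit_type in the hunk above.
    def has_any_key(*xs):
        return any(x in model_state_dict for x in xs)

    if has_any_key("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed"):
        return "wan2.1-flf2v-14b"
    elif has_any_key("img_emb.proj.0.weight", "condition_embedder.image_embedder.norm1"):
        return "wan2.1-i2v-14b"
    elif has_any_key("blocks.39.self_attn.norm_q.weight", "blocks.39.attn1.norm_q.weight"):
        return "wan2.1-t2v-14b"
    return "wan2.1-t2v-1.3b"


native = {"blocks.39.self_attn.norm_q.weight": torch.zeros(1)}
diffusers_style = {"blocks.39.attn1.norm_q.weight": torch.zeros(1)}
print(get_dit_type_sketch(native), get_dit_type_sketch(diffusers_style))
# wan2.1-t2v-14b wan2.1-t2v-14b
```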
diffsynth_engine/tokenizers/base.py
CHANGED

@@ -1,10 +1,16 @@
 # Modified from transformers.tokenization_utils_base
 from typing import Dict, List, Union, overload
+from enum import Enum
 
 
 TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
 
 
+class PaddingStrategy(str, Enum):
+    LONGEST = "longest"
+    MAX_LENGTH = "max_length"
+
+
 class BaseTokenizer:
     SPECIAL_TOKENS_ATTRIBUTES = [
         "bos_token",
diffsynth_engine/tokenizers/qwen2.py
CHANGED

@@ -4,7 +4,7 @@ import torch
 from typing import Dict, List, Union, Optional
 from tokenizers import Tokenizer as TokenizerFast, AddedToken
 
-from diffsynth_engine.tokenizers.base import BaseTokenizer, TOKENIZER_CONFIG_FILE
+from diffsynth_engine.tokenizers.base import BaseTokenizer, PaddingStrategy, TOKENIZER_CONFIG_FILE
 
 
 VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}

@@ -165,22 +165,28 @@ class Qwen2TokenizerFast(BaseTokenizer):
         texts: Union[str, List[str]],
         max_length: Optional[int] = None,
         padding_side: Optional[str] = None,
+        padding_strategy: Union[PaddingStrategy, str] = "longest",
         **kwargs,
     ) -> Dict[str, "torch.Tensor"]:
         """
         Tokenize text and prepare for model inputs.
 
         Args:
-
+            texts (`str`, `List[str]`):
                 The sequence or batch of sequences to be encoded.
 
             max_length (`int`, *optional*):
-
+                Maximum length of the encoded sequences.
 
             padding_side (`str`, *optional*):
                 The side on which the padding should be applied. Should be selected between `"right"` and `"left"`.
                 Defaults to `"right"`.
 
+            padding_strategy (`PaddingStrategy`, `str`, *optional*):
+                If `"longest"`, will pad the sequences to the longest sequence in the batch.
+                If `"max_length"`, will pad the sequences to the `max_length` argument.
+                Defaults to `"longest"`.
+
         Returns:
             `Dict[str, "torch.Tensor"]`: tensor dict compatible with model_input_names.
         """

@@ -190,7 +196,9 @@ class Qwen2TokenizerFast(BaseTokenizer):
 
         batch_ids = self.batch_encode(texts)
         ids_lens = [len(ids_) for ids_ in batch_ids]
-        max_length = max_length if max_length is not None else
+        max_length = max_length if max_length is not None else self.model_max_length
+        if padding_strategy == PaddingStrategy.LONGEST:
+            max_length = min(max(ids_lens), max_length)
         padding_side = padding_side if padding_side is not None else self.padding_side
 
         encoded = torch.zeros(len(texts), max_length, dtype=torch.long)
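The new `padding_strategy` argument only changes how the target length is computed: `"longest"` (the default) clamps `max_length` to the longest sequence in the batch, while `"max_length"` always pads out to `max_length` (falling back to the tokenizer's `model_max_length` when none is given). A standalone illustration of just that length computation; `model_max_length=32768` is an example value, not the tokenizer's actual setting:

```python
def effective_max_length(ids_lens, max_length=None, model_max_length=32768,
                         padding_strategy="longest"):
    # Mirrors the length computation added to Qwen2TokenizerFast.__call__.
    max_length = max_length if max_length is not None else model_max_length
    if padding_strategy == "longest":
        max_length = min(max(ids_lens), max_length)
    return max_length


lens = [7, 19, 11]
print(effective_max_length(lens, max_length=64, padding_strategy="longest"))     # 19
print(effective_max_length(lens, max_length=64, padding_strategy="max_length"))  # 64
```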
diffsynth_engine/utils/constants.py
CHANGED

@@ -27,18 +27,19 @@ SD3_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd3", "sd3_tex
 SDXL_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_text_encoder.json")
 SDXL_UNET_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_unet.json")
 
-WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.
-WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.
-WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.
-WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.
-WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.
-WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.
-WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.
-WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.
-
-
-
-
+WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_1.3b.json")
+WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_14b.json")
+WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_i2v_14b.json")
+WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_flf2v_14b.json")
+WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_ti2v_5b.json")
+WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_t2v_a14b.json")
+WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_i2v_a14b.json")
+WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_s2v_14b.json")
+WAN_DIT_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan_dit_keymap.json")
+
+WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1_vae.json")
+WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2_vae.json")
+WAN_VAE_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan_vae_keymap.json")
 
 QWEN_IMAGE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen2_5_vl_config.json")
 QWEN_IMAGE_VISION_CONFIG_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen2_5_vl_vision_config.json")
diffsynth_engine/utils/flag.py
CHANGED
@@ -44,3 +44,9 @@ if SPARGE_ATTN_AVAILABLE:
     logger.info("Sparge attention is available")
 else:
     logger.info("Sparge attention is not available")
+
+VIDEO_SPARSE_ATTN_AVAILABLE = importlib.util.find_spec("vsa") is not None
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    logger.info("Video sparse attention is available")
+else:
+    logger.info("Video sparse attention is not available")
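`VIDEO_SPARSE_ATTN_AVAILABLE` follows the same optional-dependency pattern as the existing flags: probe for the `vsa` package without importing it, then log the result. The diff only defines the flag; how a caller might use it to guard the new `"vsa"` attention implementation is sketched below as an assumption, not library behaviour:

```python
import importlib.util

VIDEO_SPARSE_ATTN_AVAILABLE = importlib.util.find_spec("vsa") is not None


def choose_attn_impl(requested: str) -> str:
    # Hypothetical guard: fall back to SDPA when the optional vsa package is missing.
    if requested == "vsa" and not VIDEO_SPARSE_ATTN_AVAILABLE:
        return "sdpa"
    return requested


print(choose_attn_impl("vsa"))  # "sdpa" unless the vsa package is installed
```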