diffsynth-engine 0.5.1.dev4__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +12 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
- diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
- diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
- diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
- diffsynth_engine/configs/__init__.py +16 -1
- diffsynth_engine/configs/controlnet.py +13 -0
- diffsynth_engine/configs/pipeline.py +37 -11
- diffsynth_engine/models/base.py +1 -1
- diffsynth_engine/models/basic/attention.py +105 -43
- diffsynth_engine/models/basic/transformer_helper.py +36 -2
- diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
- diffsynth_engine/models/flux/flux_controlnet.py +16 -30
- diffsynth_engine/models/flux/flux_dit.py +49 -62
- diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
- diffsynth_engine/models/flux/flux_vae.py +20 -2
- diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
- diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
- diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
- diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
- diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
- diffsynth_engine/models/sd/sd_unet.py +1 -1
- diffsynth_engine/models/sd3/sd3_dit.py +1 -1
- diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
- diffsynth_engine/models/vae/vae.py +1 -1
- diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
- diffsynth_engine/models/wan/wan_dit.py +65 -28
- diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
- diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
- diffsynth_engine/models/wan/wan_vae.py +2 -2
- diffsynth_engine/pipelines/base.py +73 -7
- diffsynth_engine/pipelines/flux_image.py +139 -120
- diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
- diffsynth_engine/pipelines/qwen_image.py +272 -87
- diffsynth_engine/pipelines/sdxl_image.py +1 -1
- diffsynth_engine/pipelines/utils.py +52 -0
- diffsynth_engine/pipelines/wan_s2v.py +25 -14
- diffsynth_engine/pipelines/wan_video.py +43 -19
- diffsynth_engine/tokenizers/base.py +6 -0
- diffsynth_engine/tokenizers/qwen2.py +12 -4
- diffsynth_engine/utils/constants.py +13 -12
- diffsynth_engine/utils/download.py +4 -2
- diffsynth_engine/utils/env.py +2 -0
- diffsynth_engine/utils/flag.py +6 -0
- diffsynth_engine/utils/loader.py +25 -6
- diffsynth_engine/utils/parallel.py +62 -29
- diffsynth_engine/utils/video.py +3 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
diffsynth_engine/models/wan/wan_dit.py

@@ -17,6 +17,7 @@ from diffsynth_engine.utils.constants import (
     WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
     WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
     WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
+    WAN_DIT_KEYMAP_FILE,
 )
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
@@ -30,6 +31,9 @@ from diffsynth_engine.utils.parallel import (
 T5_TOKEN_NUM = 512
 FLF_TOKEN_NUM = 257 * 2

+with open(WAN_DIT_KEYMAP_FILE, "r", encoding="utf-8") as f:
+    config = json.load(f)
+

 def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
     return x * (1 + scale) + shift
@@ -73,7 +77,7 @@ class SelfAttention(nn.Module):
         dim: int,
         num_heads: int,
         eps: float = 1e-6,
-
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -86,19 +90,25 @@ class SelfAttention(nn.Module):
         self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
         self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
         self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.
+        self.gate_compress = nn.Linear(dim, dim, device=device, dtype=dtype) if use_vsa else None

-    def forward(self, x, freqs):
+    def forward(self, x, freqs, attn_kwargs=None):
         q, k, v = self.norm_q(self.q(x)), self.norm_k(self.k(x)), self.v(x)
+        g = self.gate_compress(x) if self.gate_compress is not None else None
+
         num_heads = q.shape[2] // self.head_dim
         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
+        g = rearrange(g, "b s (n d) -> b s n d", n=num_heads) if g is not None else None
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
         x = attention_ops.attention(
             q=rope_apply(q, freqs),
             k=rope_apply(k, freqs),
             v=v,
-
+            g=g,
+            **attn_kwargs,
         )
         x = x.flatten(2)
         return self.o(x)
@@ -111,7 +121,6 @@ class CrossAttention(nn.Module):
         num_heads: int,
         eps: float = 1e-6,
         has_image_input: bool = False,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -130,9 +139,8 @@ class CrossAttention(nn.Module):
             self.k_img = nn.Linear(dim, dim, device=device, dtype=dtype)
             self.v_img = nn.Linear(dim, dim, device=device, dtype=dtype)
             self.norm_k_img = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

-    def forward(self, x: torch.Tensor, y: torch.Tensor):
+    def forward(self, x: torch.Tensor, y: torch.Tensor, attn_kwargs=None):
         if self.has_image_input:
             img = y[:, :-T5_TOKEN_NUM]
             ctx = y[:, -T5_TOKEN_NUM:]
@@ -144,12 +152,16 @@ class CrossAttention(nn.Module):
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)

-
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        if attn_kwargs.get("attn_impl", None) == "vsa":
+            attn_kwargs = attn_kwargs.copy()
+            attn_kwargs["attn_impl"] = "sdpa"
+        x = attention(q, k, v, **attn_kwargs).flatten(2)
         if self.has_image_input:
             k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
             k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
             v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
-            y = attention(q, k_img, v_img, **
+            y = attention(q, k_img, v_img, **attn_kwargs).flatten(2)
         x = x + y
         return self.o(x)

@@ -162,7 +174,7 @@ class DiTBlock(nn.Module):
         num_heads: int,
         ffn_dim: int,
         eps: float = 1e-6,
-
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -170,9 +182,9 @@ class DiTBlock(nn.Module):
         self.dim = dim
         self.num_heads = num_heads
         self.ffn_dim = ffn_dim
-        self.self_attn = SelfAttention(dim, num_heads, eps,
+        self.self_attn = SelfAttention(dim, num_heads, eps, use_vsa=use_vsa, device=device, dtype=dtype)
         self.cross_attn = CrossAttention(
-            dim, num_heads, eps, has_image_input=has_image_input,
+            dim, num_heads, eps, has_image_input=has_image_input, device=device, dtype=dtype
         )
         self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
         self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
@@ -184,14 +196,14 @@ class DiTBlock(nn.Module):
         )
         self.modulation = nn.Parameter(torch.randn(1, 6, dim, device=device, dtype=dtype) / dim**0.5)

-    def forward(self, x, context, t_mod, freqs):
+    def forward(self, x, context, t_mod, freqs, attn_kwargs=None):
         # msa: multi-head self-attention  mlp: multi-layer perceptron
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
             t.squeeze(1) for t in (self.modulation + t_mod).chunk(6, dim=1)
         ]
         input_x = modulate(self.norm1(x), shift_msa, scale_msa)
-        x = x + gate_msa * self.self_attn(input_x, freqs)
-        x = x + self.cross_attn(self.norm3(x), context)
+        x = x + gate_msa * self.self_attn(input_x, freqs, attn_kwargs)
+        x = x + self.cross_attn(self.norm3(x), context, attn_kwargs)
         input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
         x = x + gate_mlp * self.ffn(input_x)
         return x
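The attention implementation is now chosen per call rather than frozen into each attention module at construction time. A minimal, self-contained sketch of the kwargs involved (only "vsa" and "sdpa" appear in this diff; any extra keys are forwarded to the attention op):

```python
# Per-call attention configuration passed down WanDiT.forward -> DiTBlock -> attention.
attn_kwargs = {"attn_impl": "vsa"}

# CrossAttention.forward downgrades a "vsa" request to plain SDPA, since video sparse
# attention is only wired into the self-attention path (via the gate_compress output):
if attn_kwargs.get("attn_impl", None) == "vsa":
    attn_kwargs = attn_kwargs.copy()
    attn_kwargs["attn_impl"] = "sdpa"
assert attn_kwargs == {"attn_impl": "sdpa"}
```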
@@ -249,7 +261,26 @@ class Head(nn.Module):


 class WanDiTStateDictConverter(StateDictConverter):
+    def _from_diffusers(self, state_dict):
+        global_rename_dict = config["diffusers"]["global_rename_dict"]
+        rename_dict = config["diffusers"]["rename_dict"]
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            suffix = ""
+            suffix = ".weight" if name.endswith(".weight") else suffix
+            suffix = ".bias" if name.endswith(".bias") else suffix
+            prefix = name[: -len(suffix)] if suffix else name
+            if prefix in global_rename_dict:
+                state_dict_[f"{global_rename_dict[prefix]}{suffix}"] = param
+            if prefix.startswith("blocks."):
+                _, idx, middle = prefix.split(".", 2)
+                if middle in rename_dict:
+                    state_dict_[f"blocks.{idx}.{rename_dict[middle]}{suffix}"] = param
+        return state_dict_
+
     def convert(self, state_dict):
+        if "condition_embedder.time_proj.weight" in state_dict:
+            return self._from_diffusers(state_dict)
         return state_dict


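A hedged sketch of how the new Diffusers-format support would be exercised; the checkpoint path is a placeholder and `load_file` is the standard safetensors loader, while the marker key comes straight from the hunk above:

```python
from safetensors.torch import load_file

state_dict = load_file("diffusion_pytorch_model.safetensors")  # placeholder path
converter = WanDiTStateDictConverter()
# convert() now detects the Diffusers layout via "condition_embedder.time_proj.weight"
# and routes through _from_diffusers(), which renames keys using the mapping loaded
# from the new wan_dit_keymap.json at module import time.
state_dict = converter.convert(state_dict)
```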
@@ -273,7 +304,7 @@ class WanDiT(PreTrainedModel):
         has_vae_feature: bool = False,
         fuse_image_latents: bool = False,
         flf_pos_emb: bool = False,
-
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -307,7 +338,16 @@ class WanDiT(PreTrainedModel):
         )
         self.blocks = nn.ModuleList(
             [
-                DiTBlock(
+                DiTBlock(
+                    has_clip_feature,
+                    dim,
+                    num_heads,
+                    ffn_dim,
+                    eps,
+                    use_vsa,
+                    device=device,
+                    dtype=dtype,
+                )
                 for _ in range(num_layers)
             ]
         )
@@ -344,6 +384,7 @@ class WanDiT(PreTrainedModel):
         timestep: torch.Tensor,
         clip_feature: Optional[torch.Tensor] = None,  # clip_vision_encoder(img)
         y: Optional[torch.Tensor] = None,  # vae_encoder(img)
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         use_cfg = x.shape[0] > 1
@@ -376,7 +417,7 @@ class WanDiT(PreTrainedModel):

         with sequence_parallel((x, t, t_mod, freqs), seq_dims=(1, 0, 0, 0)):
             for block in self.blocks:
-                x = block(x, context, t_mod, freqs)
+                x = block(x, context, t_mod, freqs, attn_kwargs)
             x = self.head(x, t)
             (x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(f * h * w,))
         x = self.unpatchify(x, (f, h, w))
@@ -398,7 +439,7 @@ class WanDiT(PreTrainedModel):
             raise ValueError(f"Unsupported model type: {model_type}")

         config_file = MODEL_CONFIG_FILES[model_type]
-        with open(config_file, "r") as f:
+        with open(config_file, "r", encoding="utf-8") as f:
             config = json.load(f)
         return config

@@ -409,12 +450,11 @@ class WanDiT(PreTrainedModel):
         config: Dict[str, Any],
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
-
-        assign: bool = True,
+        use_vsa: bool = False,
     ):
-        model = cls(**config, device="meta", dtype=dtype,
+        model = cls(**config, device="meta", dtype=dtype, use_vsa=use_vsa)
         model = model.requires_grad_(False)
-        model.load_state_dict(state_dict, assign=
+        model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
         return model

@@ -499,8 +539,5 @@ class WanDiT(PreTrainedModel):
         for block in self.blocks:
             block.compile(*args, **kwargs)

-
-
-
-    def get_fsdp_modules(self):
-        return ["blocks"]
+    def get_fsdp_module_cls(self):
+        return {DiTBlock}
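For context, a hedged sketch of building the DiT with the reworked from_state_dict; the checkpoint and config paths are placeholders, and the leading state_dict parameter is assumed from the hunk body (it is not shown in the context lines):

```python
import json
import torch
from safetensors.torch import load_file

with open("wan2.2_t2v_a14b.json", "r", encoding="utf-8") as f:  # placeholder config path
    config = json.load(f)
state_dict = load_file("wan_dit.safetensors")  # placeholder checkpoint path

dit = WanDiT.from_state_dict(
    state_dict,
    config,
    device="cuda:0",
    dtype=torch.bfloat16,
    use_vsa=True,  # builds the gate_compress projections used by video sparse attention
)
```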
diffsynth_engine/models/wan/wan_s2v_dit.py

@@ -360,7 +360,7 @@ class WanS2VDiT(WanDiT):
             raise ValueError(f"Unsupported model type: {model_type}")

         config_file = MODEL_CONFIG_FILES[model_type]
-        with open(config_file, "r") as f:
+        with open(config_file, "r", encoding="utf-8") as f:
             config = json.load(f)
         return config


diffsynth_engine/models/wan/wan_text_encoder.py

@@ -198,22 +198,22 @@ class WanTextEncoderStateDictConverter(StateDictConverter):

     def _from_diffusers(self, state_dict):
         rename_dict = {
-            "
-            "
+            "shared.weight": "token_embedding.weight",
+            "encoder.final_layer_norm.weight": "norm.weight",
         }
         for i in range(self.num_encoder_layers):
             rename_dict.update(
                 {
-                    f"
-                    f"
-                    f"
-                    f"
-                    f"
-                    f"
-                    f"
-                    f"
-                    f"
-                    f"
+                    f"encoder.block.{i}.layer.0.SelfAttention.q.weight": f"blocks.{i}.attn.q.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.k.weight": f"blocks.{i}.attn.k.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.v.weight": f"blocks.{i}.attn.v.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.o.weight": f"blocks.{i}.attn.o.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.relative_attention_bias.weight": f"blocks.{i}.pos_embedding.embedding.weight",
+                    f"encoder.block.{i}.layer.0.layer_norm.weight": f"blocks.{i}.norm1.weight",
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight": f"blocks.{i}.ffn.gate.0.weight",
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight": f"blocks.{i}.ffn.fc1.weight",
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight": f"blocks.{i}.ffn.fc2.weight",
+                    f"encoder.block.{i}.layer.1.layer_norm.weight": f"blocks.{i}.norm2.weight",
                 }
             )

@@ -224,7 +224,7 @@ class WanTextEncoderStateDictConverter(StateDictConverter):
         return new_state_dict

     def convert(self, state_dict):
-        if "
+        if "encoder.final_layer_norm.weight" in state_dict:
             logger.info("use diffusers format state dict")
             return self._from_diffusers(state_dict)
         return state_dict

diffsynth_engine/models/wan/wan_vae.py

@@ -12,7 +12,7 @@ from diffsynth_engine.utils.constants import WAN2_1_VAE_CONFIG_FILE, WAN2_2_VAE_

 CACHE_T = 2

-with open(WAN_VAE_KEYMAP_FILE, "r") as f:
+with open(WAN_VAE_KEYMAP_FILE, "r", encoding="utf-8") as f:
     config = json.load(f)


@@ -855,7 +855,7 @@ class WanVideoVAE(PreTrainedModel):
             raise ValueError(f"Unsupported model type: {model_type}")

         config_file = MODEL_CONFIG_FILES[model_type]
-        with open(config_file, "r") as f:
+        with open(config_file, "r", encoding="utf-8") as f:
             config = json.load(f)
         return config

diffsynth_engine/pipelines/base.py

@@ -2,10 +2,18 @@ import os
 import torch
 import numpy as np
 from einops import rearrange
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Union, Optional
 from PIL import Image

-from diffsynth_engine.configs import
+from diffsynth_engine.configs import (
+    BaseConfig,
+    BaseStateDicts,
+    LoraConfig,
+    AttnImpl,
+    SpargeAttentionParams,
+    VideoSparseAttentionParams,
+)
+from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs
 from diffsynth_engine.utils.offload import enable_sequential_cpu_offload, offload_model_to_dict, restore_model_from_dict
 from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
 from diffsynth_engine.utils.gguf import load_gguf_checkpoint
@@ -33,6 +41,7 @@ class BasePipeline:
         dtype=torch.float16,
     ):
         super().__init__()
+        self.config = None
         self.vae_tiled = vae_tiled
         self.vae_tile_size = vae_tile_size
         self.vae_tile_stride = vae_tile_stride
@@ -48,14 +57,49 @@ class BasePipeline:
         raise NotImplementedError()

     @classmethod
-    def from_state_dict(cls, state_dicts: BaseStateDicts,
+    def from_state_dict(cls, state_dicts: BaseStateDicts, config: BaseConfig) -> "BasePipeline":
         raise NotImplementedError()

-    def
-
-
+    def update_weights(self, state_dicts: BaseStateDicts) -> None:
+        raise NotImplementedError()
+
+    @staticmethod
+    def update_component(
+        component: torch.nn.Module,
+        state_dict: Dict[str, torch.Tensor],
+        device: str,
+        dtype: torch.dtype,
+    ) -> None:
+        if component and state_dict:
+            component.load_state_dict(state_dict, assign=True)
+            component.to(device=device, dtype=dtype, non_blocking=True)
+
+    def load_loras(
+        self,
+        lora_list: List[Tuple[str, Union[float, LoraConfig]]],
+        fused: bool = True,
+        save_original_weight: bool = False,
+        lora_converter: Optional[LoRAStateDictConverter] = None,
+    ):
+        if not lora_converter:
+            lora_converter = self.lora_converter
+
+        for lora_path, lora_item in lora_list:
+            if isinstance(lora_item, float):
+                lora_scale = lora_item
+                scheduler_config = None
+            if isinstance(lora_item, LoraConfig):
+                lora_scale = lora_item.scale
+                scheduler_config = lora_item.scheduler_config
+
+            logger.info(f"loading lora from {lora_path} with LoraConfig (scale={lora_scale})")
             state_dict = load_file(lora_path, device=self.device)
-
+
+            if scheduler_config is not None:
+                self.apply_scheduler_config(scheduler_config)
+                logger.info(f"Applied scheduler args from LoraConfig: {scheduler_config}")
+
+            lora_state_dict = lora_converter.convert(state_dict)
         for model_name, state_dict in lora_state_dict.items():
             model = getattr(self, model_name)
             lora_args = []
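A hedged usage sketch of the extended LoRA loading for an already-constructed pipeline `pipe`; the file names and the scheduler_config keys are placeholders, and LoraConfig is assumed to accept `scale` and `scheduler_config` as constructor arguments (matching the attributes read above):

```python
from diffsynth_engine.configs import LoraConfig

pipe.load_loras(
    [
        ("style_lora.safetensors", 0.8),  # bare float scale, same as before
        (
            "motion_lora.safetensors",
            LoraConfig(scale=1.0, scheduler_config={"shift": 5.0}),  # keys illustrative
        ),
    ],
    fused=True,
)
# A LoraConfig carrying scheduler_config also triggers apply_scheduler_config(), a no-op
# in BasePipeline that concrete pipelines can override (see the next hunk).
```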
@@ -78,6 +122,9 @@ class BasePipeline:
     def load_lora(self, path: str, scale: float, fused: bool = True, save_original_weight: bool = False):
         self.load_loras([(path, scale)], fused, save_original_weight)

+    def apply_scheduler_config(self, scheduler_config: Dict):
+        pass
+
     def unload_loras(self):
         raise NotImplementedError()

@@ -222,6 +269,25 @@ class BasePipeline:
         )
         return init_latents, latents, sigmas, timesteps

+    def get_attn_kwargs(self, latents: torch.Tensor) -> Dict:
+        attn_kwargs = {"attn_impl": self.config.dit_attn_impl.value}
+        if isinstance(self.config.attn_params, SpargeAttentionParams):
+            assert self.config.dit_attn_impl == AttnImpl.SPARGE
+            attn_kwargs.update(
+                {
+                    "smooth_k": self.config.attn_params.smooth_k,
+                    "simthreshd1": self.config.attn_params.simthreshd1,
+                    "cdfthreshd": self.config.attn_params.cdfthreshd,
+                    "pvthreshd": self.config.attn_params.pvthreshd,
+                }
+            )
+        elif isinstance(self.config.attn_params, VideoSparseAttentionParams):
+            assert self.config.dit_attn_impl == AttnImpl.VSA
+            attn_kwargs.update(
+                get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.config.attn_params.sparsity, device=self.device)
+            )
+        return attn_kwargs
+
     def eval(self):
         for model_name in self.model_names:
             model = getattr(self, model_name)