diffsynth-engine 0.3.6.dev13__py3-none-any.whl → 0.3.6.dev14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +2 -3
- diffsynth_engine/conf/models/wan/dit/{14b-i2v.json → wan2.1-flf2v-14b.json} +5 -2
- diffsynth_engine/conf/models/wan/dit/{14b-flf2v.json → wan2.1-i2v-14b.json} +2 -2
- diffsynth_engine/conf/models/wan/dit/{1.3b-t2v.json → wan2.1-t2v-1.3b.json} +0 -1
- diffsynth_engine/conf/models/wan/dit/{14b-t2v.json → wan2.1-t2v-14b.json} +0 -1
- diffsynth_engine/conf/models/wan/dit/wan2.2-i2v-a14b.json +16 -0
- diffsynth_engine/conf/models/wan/dit/wan2.2-t2v-a14b.json +16 -0
- diffsynth_engine/conf/models/wan/dit/wan2.2-ti2v-5b.json +14 -0
- diffsynth_engine/conf/models/wan/vae/wan2.1-vae.json +48 -0
- diffsynth_engine/conf/models/wan/vae/wan2.2-vae.json +112 -0
- diffsynth_engine/configs/pipeline.py +6 -1
- diffsynth_engine/models/wan/wan_dit.py +52 -32
- diffsynth_engine/models/wan/wan_vae.py +355 -60
- diffsynth_engine/pipelines/base.py +15 -11
- diffsynth_engine/pipelines/wan_video.py +175 -74
- diffsynth_engine/utils/constants.py +10 -4
- diffsynth_engine/utils/parallel.py +3 -1
- {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/RECORD +22 -17
- {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/top_level.txt +0 -0
|
@@ -9,13 +9,12 @@ class FlowMatchEulerSampler:
|
|
|
9
9
|
self.mask = mask
|
|
10
10
|
|
|
11
11
|
def step(self, latents, model_outputs, i):
|
|
12
|
-
if self.mask is not None:
|
|
13
|
-
model_outputs = model_outputs * self.mask + self.init_latents * (1 - self.mask)
|
|
14
|
-
|
|
15
12
|
dt = self.sigmas[i + 1] - self.sigmas[i]
|
|
16
13
|
latents = latents.to(dtype=torch.float32)
|
|
17
14
|
latents = latents + model_outputs * dt
|
|
18
15
|
latents = latents.to(dtype=model_outputs.dtype)
|
|
16
|
+
if self.mask is not None:
|
|
17
|
+
latents = latents * self.mask + self.init_latents * (1 - self.mask)
|
|
19
18
|
return latents
|
|
20
19
|
|
|
21
20
|
def add_noise(self, latents, noise, sigma):
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
{
|
|
2
|
-
"
|
|
2
|
+
"has_clip_feature": true,
|
|
3
|
+
"has_vae_feature": true,
|
|
4
|
+
"flf_pos_emb": true,
|
|
3
5
|
"patch_size": [1, 2, 2],
|
|
4
6
|
"in_dim": 36,
|
|
5
7
|
"dim": 5120,
|
|
@@ -9,5 +11,6 @@
|
|
|
9
11
|
"out_dim": 16,
|
|
10
12
|
"num_heads": 40,
|
|
11
13
|
"num_layers": 40,
|
|
12
|
-
"eps": 1e-6
|
|
14
|
+
"eps": 1e-6,
|
|
15
|
+
"shift": 16.0
|
|
13
16
|
}
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
{
|
|
2
|
+
"has_vae_feature": true,
|
|
3
|
+
"patch_size": [1, 2, 2],
|
|
4
|
+
"in_dim": 36,
|
|
5
|
+
"dim": 5120,
|
|
6
|
+
"ffn_dim": 13824,
|
|
7
|
+
"freq_dim": 256,
|
|
8
|
+
"text_dim": 4096,
|
|
9
|
+
"out_dim": 16,
|
|
10
|
+
"num_heads": 40,
|
|
11
|
+
"num_layers": 40,
|
|
12
|
+
"eps": 1e-6,
|
|
13
|
+
"boundary": 0.900,
|
|
14
|
+
"cfg_scale": [3.5, 3.5],
|
|
15
|
+
"num_inference_steps": 40
|
|
16
|
+
}
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
{
|
|
2
|
+
"patch_size": [1, 2, 2],
|
|
3
|
+
"in_dim": 16,
|
|
4
|
+
"dim": 5120,
|
|
5
|
+
"ffn_dim": 13824,
|
|
6
|
+
"freq_dim": 256,
|
|
7
|
+
"text_dim": 4096,
|
|
8
|
+
"out_dim": 16,
|
|
9
|
+
"num_heads": 40,
|
|
10
|
+
"num_layers": 40,
|
|
11
|
+
"eps": 1e-6,
|
|
12
|
+
"boundary": 0.875,
|
|
13
|
+
"shift": 12.0,
|
|
14
|
+
"cfg_scale": [3.0, 4.0],
|
|
15
|
+
"num_inference_steps": 40
|
|
16
|
+
}
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
{
|
|
2
|
+
"in_channels": 3,
|
|
3
|
+
"out_channels": 3,
|
|
4
|
+
"encoder_dim": 96,
|
|
5
|
+
"decoder_dim": 96,
|
|
6
|
+
"z_dim": 16,
|
|
7
|
+
"dim_mult": [1, 2, 4, 4],
|
|
8
|
+
"num_res_blocks": 2,
|
|
9
|
+
"temperal_downsample": [false, true, true],
|
|
10
|
+
"dropout": 0.0,
|
|
11
|
+
"patch_size": 1,
|
|
12
|
+
"mean": [
|
|
13
|
+
-0.7571,
|
|
14
|
+
-0.7089,
|
|
15
|
+
-0.9113,
|
|
16
|
+
0.1075,
|
|
17
|
+
-0.1745,
|
|
18
|
+
0.9653,
|
|
19
|
+
-0.1517,
|
|
20
|
+
1.5508,
|
|
21
|
+
0.4134,
|
|
22
|
+
-0.0715,
|
|
23
|
+
0.5517,
|
|
24
|
+
-0.3632,
|
|
25
|
+
-0.1922,
|
|
26
|
+
-0.9497,
|
|
27
|
+
0.2503,
|
|
28
|
+
-0.2921
|
|
29
|
+
],
|
|
30
|
+
"std": [
|
|
31
|
+
2.8184,
|
|
32
|
+
1.4541,
|
|
33
|
+
2.3275,
|
|
34
|
+
2.6558,
|
|
35
|
+
1.2196,
|
|
36
|
+
1.7708,
|
|
37
|
+
2.6052,
|
|
38
|
+
2.0743,
|
|
39
|
+
3.2687,
|
|
40
|
+
2.1526,
|
|
41
|
+
2.8652,
|
|
42
|
+
1.5579,
|
|
43
|
+
1.6382,
|
|
44
|
+
1.1253,
|
|
45
|
+
2.8251,
|
|
46
|
+
1.9160
|
|
47
|
+
]
|
|
48
|
+
}
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
{
|
|
2
|
+
"in_channels": 12,
|
|
3
|
+
"out_channels": 12,
|
|
4
|
+
"encoder_dim": 160,
|
|
5
|
+
"decoder_dim": 256,
|
|
6
|
+
"z_dim": 48,
|
|
7
|
+
"dim_mult": [1, 2, 4, 4],
|
|
8
|
+
"num_res_blocks": 2,
|
|
9
|
+
"temperal_downsample": [false, true, true],
|
|
10
|
+
"dropout": 0.0,
|
|
11
|
+
"patch_size": 2,
|
|
12
|
+
"mean": [
|
|
13
|
+
-0.2289,
|
|
14
|
+
-0.0052,
|
|
15
|
+
-0.1323,
|
|
16
|
+
-0.2339,
|
|
17
|
+
-0.2799,
|
|
18
|
+
0.0174,
|
|
19
|
+
0.1838,
|
|
20
|
+
0.1557,
|
|
21
|
+
-0.1382,
|
|
22
|
+
0.0542,
|
|
23
|
+
0.2813,
|
|
24
|
+
0.0891,
|
|
25
|
+
0.1570,
|
|
26
|
+
-0.0098,
|
|
27
|
+
0.0375,
|
|
28
|
+
-0.1825,
|
|
29
|
+
-0.2246,
|
|
30
|
+
-0.1207,
|
|
31
|
+
-0.0698,
|
|
32
|
+
0.5109,
|
|
33
|
+
0.2665,
|
|
34
|
+
-0.2108,
|
|
35
|
+
-0.2158,
|
|
36
|
+
0.2502,
|
|
37
|
+
-0.2055,
|
|
38
|
+
-0.0322,
|
|
39
|
+
0.1109,
|
|
40
|
+
0.1567,
|
|
41
|
+
-0.0729,
|
|
42
|
+
0.0899,
|
|
43
|
+
-0.2799,
|
|
44
|
+
-0.1230,
|
|
45
|
+
-0.0313,
|
|
46
|
+
-0.1649,
|
|
47
|
+
0.0117,
|
|
48
|
+
0.0723,
|
|
49
|
+
-0.2839,
|
|
50
|
+
-0.2083,
|
|
51
|
+
-0.0520,
|
|
52
|
+
0.3748,
|
|
53
|
+
0.0152,
|
|
54
|
+
0.1957,
|
|
55
|
+
0.1433,
|
|
56
|
+
-0.2944,
|
|
57
|
+
0.3573,
|
|
58
|
+
-0.0548,
|
|
59
|
+
-0.1681,
|
|
60
|
+
-0.0667
|
|
61
|
+
],
|
|
62
|
+
"std": [
|
|
63
|
+
0.4765,
|
|
64
|
+
1.0364,
|
|
65
|
+
0.4514,
|
|
66
|
+
1.1677,
|
|
67
|
+
0.5313,
|
|
68
|
+
0.4990,
|
|
69
|
+
0.4818,
|
|
70
|
+
0.5013,
|
|
71
|
+
0.8158,
|
|
72
|
+
1.0344,
|
|
73
|
+
0.5894,
|
|
74
|
+
1.0901,
|
|
75
|
+
0.6885,
|
|
76
|
+
0.6165,
|
|
77
|
+
0.8454,
|
|
78
|
+
0.4978,
|
|
79
|
+
0.5759,
|
|
80
|
+
0.3523,
|
|
81
|
+
0.7135,
|
|
82
|
+
0.6804,
|
|
83
|
+
0.5833,
|
|
84
|
+
1.4146,
|
|
85
|
+
0.8986,
|
|
86
|
+
0.5659,
|
|
87
|
+
0.7069,
|
|
88
|
+
0.5338,
|
|
89
|
+
0.4889,
|
|
90
|
+
0.4917,
|
|
91
|
+
0.4069,
|
|
92
|
+
0.4999,
|
|
93
|
+
0.6866,
|
|
94
|
+
0.4093,
|
|
95
|
+
0.5709,
|
|
96
|
+
0.6065,
|
|
97
|
+
0.6415,
|
|
98
|
+
0.4944,
|
|
99
|
+
0.5726,
|
|
100
|
+
1.2042,
|
|
101
|
+
0.5458,
|
|
102
|
+
1.6887,
|
|
103
|
+
0.3971,
|
|
104
|
+
1.0600,
|
|
105
|
+
0.3943,
|
|
106
|
+
0.5537,
|
|
107
|
+
0.5444,
|
|
108
|
+
0.4089,
|
|
109
|
+
0.7468,
|
|
110
|
+
0.7744
|
|
111
|
+
]
|
|
112
|
+
}
|
|
@@ -139,7 +139,12 @@ class WanPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, Bas
|
|
|
139
139
|
vae_dtype: torch.dtype = torch.bfloat16
|
|
140
140
|
image_encoder_dtype: torch.dtype = torch.bfloat16
|
|
141
141
|
|
|
142
|
-
|
|
142
|
+
# default params set by model type
|
|
143
|
+
boundary: Optional[float] = field(default=None, init=False) # boundary
|
|
144
|
+
shift: Optional[float] = field(default=None, init=False) # RecifitedFlowScheduler shift factor
|
|
145
|
+
cfg_scale: Optional[float | Tuple[float, float]] = field(default=None, init=False) # default CFG scale
|
|
146
|
+
num_inference_steps: Optional[int] = field(default=None, init=False) # default inference steps
|
|
147
|
+
fps: Optional[int] = field(default=None, init=False) # default FPS
|
|
143
148
|
|
|
144
149
|
# override BaseConfig
|
|
145
150
|
vae_tiled: bool = True
|
|
@@ -10,10 +10,13 @@ from diffsynth_engine.models.basic import attention as attention_ops
|
|
|
10
10
|
from diffsynth_engine.models.basic.transformer_helper import RMSNorm
|
|
11
11
|
from diffsynth_engine.models.utils import no_init_weights
|
|
12
12
|
from diffsynth_engine.utils.constants import (
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
13
|
+
WAN2_1_DIT_T2V_1_3B_CONFIG_FILE,
|
|
14
|
+
WAN2_1_DIT_I2V_14B_CONFIG_FILE,
|
|
15
|
+
WAN2_1_DIT_T2V_14B_CONFIG_FILE,
|
|
16
|
+
WAN2_1_DIT_FLF2V_14B_CONFIG_FILE,
|
|
17
|
+
WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
|
|
18
|
+
WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
|
|
19
|
+
WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
|
|
17
20
|
)
|
|
18
21
|
from diffsynth_engine.utils.gguf import gguf_inference
|
|
19
22
|
from diffsynth_engine.utils.parallel import (
|
|
@@ -182,7 +185,9 @@ class DiTBlock(nn.Module):
|
|
|
182
185
|
|
|
183
186
|
def forward(self, x, context, t_mod, freqs):
|
|
184
187
|
# msa: multi-head self-attention mlp: multi-layer perceptron
|
|
185
|
-
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp =
|
|
188
|
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
|
|
189
|
+
t.squeeze(1) for t in (self.modulation + t_mod).chunk(6, dim=1)
|
|
190
|
+
]
|
|
186
191
|
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
|
187
192
|
x = x + gate_msa * self.self_attn(input_x, freqs)
|
|
188
193
|
x = x + self.cross_attn(self.norm3(x), context)
|
|
@@ -237,7 +242,7 @@ class Head(nn.Module):
|
|
|
237
242
|
self.modulation = nn.Parameter(torch.randn(1, 2, dim, device=device, dtype=dtype) / dim**0.5)
|
|
238
243
|
|
|
239
244
|
def forward(self, x, t_mod):
|
|
240
|
-
shift, scale = (self.modulation + t_mod).chunk(2, dim=1)
|
|
245
|
+
shift, scale = [t.squeeze(1) for t in (self.modulation + t_mod.unsqueeze(1)).chunk(2, dim=1)]
|
|
241
246
|
x = self.head(self.norm(x) * (1 + scale) + shift)
|
|
242
247
|
return x
|
|
243
248
|
|
|
@@ -263,17 +268,22 @@ class WanDiT(PreTrainedModel):
|
|
|
263
268
|
patch_size: Tuple[int, int, int],
|
|
264
269
|
num_heads: int,
|
|
265
270
|
num_layers: int,
|
|
266
|
-
|
|
271
|
+
has_clip_feature: bool = False,
|
|
272
|
+
has_vae_feature: bool = False,
|
|
273
|
+
fuse_image_latents: bool = False,
|
|
267
274
|
flf_pos_emb: bool = False,
|
|
268
275
|
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
269
|
-
device: str = "
|
|
276
|
+
device: str = "cuda:0",
|
|
270
277
|
dtype: torch.dtype = torch.bfloat16,
|
|
271
278
|
):
|
|
272
279
|
super().__init__()
|
|
273
280
|
|
|
281
|
+
self.in_dim = in_dim
|
|
274
282
|
self.dim = dim
|
|
275
283
|
self.freq_dim = freq_dim
|
|
276
|
-
self.
|
|
284
|
+
self.has_clip_feature = has_clip_feature
|
|
285
|
+
self.has_vae_feature = has_vae_feature
|
|
286
|
+
self.fuse_image_latents = fuse_image_latents
|
|
277
287
|
self.patch_size = patch_size
|
|
278
288
|
|
|
279
289
|
self.patch_embedding = nn.Conv3d(
|
|
@@ -296,7 +306,7 @@ class WanDiT(PreTrainedModel):
|
|
|
296
306
|
)
|
|
297
307
|
self.blocks = nn.ModuleList(
|
|
298
308
|
[
|
|
299
|
-
DiTBlock(
|
|
309
|
+
DiTBlock(has_clip_feature, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
|
|
300
310
|
for _ in range(num_layers)
|
|
301
311
|
]
|
|
302
312
|
)
|
|
@@ -305,7 +315,7 @@ class WanDiT(PreTrainedModel):
|
|
|
305
315
|
head_dim = dim // num_heads
|
|
306
316
|
self.freqs = precompute_freqs_cis_3d(head_dim)
|
|
307
317
|
|
|
308
|
-
if
|
|
318
|
+
if has_clip_feature:
|
|
309
319
|
self.img_emb = MLP(1280, dim, flf_pos_emb, device=device, dtype=dtype) # clip_feature_dim = 1280
|
|
310
320
|
|
|
311
321
|
def patchify(self, x: torch.Tensor):
|
|
@@ -339,13 +349,14 @@ class WanDiT(PreTrainedModel):
|
|
|
339
349
|
gguf_inference(),
|
|
340
350
|
cfg_parallel((x, context, timestep, clip_feature, y), use_cfg=use_cfg),
|
|
341
351
|
):
|
|
342
|
-
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
|
|
343
|
-
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
|
352
|
+
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) # (s, d)
|
|
353
|
+
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) # (s, 6, d)
|
|
344
354
|
context = self.text_embedding(context)
|
|
345
|
-
if self.
|
|
355
|
+
if self.has_vae_feature:
|
|
346
356
|
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
|
347
|
-
|
|
348
|
-
|
|
357
|
+
if self.has_clip_feature:
|
|
358
|
+
clip_embedding = self.img_emb(clip_feature)
|
|
359
|
+
context = torch.cat([clip_embedding, context], dim=1) # (b, s1 + s2, d)
|
|
349
360
|
x, (f, h, w) = self.patchify(x)
|
|
350
361
|
freqs = (
|
|
351
362
|
torch.cat(
|
|
@@ -360,7 +371,7 @@ class WanDiT(PreTrainedModel):
|
|
|
360
371
|
.to(x.device)
|
|
361
372
|
)
|
|
362
373
|
|
|
363
|
-
with sequence_parallel((x, freqs), seq_dims=(1, 0)):
|
|
374
|
+
with sequence_parallel((x, t, t_mod, freqs), seq_dims=(1, 0, 0, 0)):
|
|
364
375
|
for block in self.blocks:
|
|
365
376
|
x = block(x, context, t_mod, freqs)
|
|
366
377
|
x = self.head(x, t)
|
|
@@ -369,26 +380,35 @@ class WanDiT(PreTrainedModel):
|
|
|
369
380
|
(x,) = cfg_parallel_unshard((x,), use_cfg=use_cfg)
|
|
370
381
|
return x
|
|
371
382
|
|
|
383
|
+
@staticmethod
|
|
384
|
+
def get_model_config(model_type: str):
|
|
385
|
+
MODEL_CONFIG_FILES = {
|
|
386
|
+
"wan2.1-t2v-1.3b": WAN2_1_DIT_T2V_1_3B_CONFIG_FILE,
|
|
387
|
+
"wan2.1-t2v-14b": WAN2_1_DIT_T2V_14B_CONFIG_FILE,
|
|
388
|
+
"wan2.1-i2v-14b": WAN2_1_DIT_I2V_14B_CONFIG_FILE,
|
|
389
|
+
"wan2.1-flf2v-14b": WAN2_1_DIT_FLF2V_14B_CONFIG_FILE,
|
|
390
|
+
"wan2.2-ti2v-5b": WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
|
|
391
|
+
"wan2.2-t2v-a14b": WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
|
|
392
|
+
"wan2.2-i2v-a14b": WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
|
|
393
|
+
}
|
|
394
|
+
if model_type not in MODEL_CONFIG_FILES:
|
|
395
|
+
raise ValueError(f"Unsupported model type: {model_type}")
|
|
396
|
+
|
|
397
|
+
config_file = MODEL_CONFIG_FILES[model_type]
|
|
398
|
+
with open(config_file, "r") as f:
|
|
399
|
+
config = json.load(f)
|
|
400
|
+
return config
|
|
401
|
+
|
|
372
402
|
@classmethod
|
|
373
403
|
def from_state_dict(
|
|
374
404
|
cls,
|
|
375
|
-
state_dict,
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
405
|
+
state_dict: Dict[str, torch.Tensor],
|
|
406
|
+
config: Dict[str, Any],
|
|
407
|
+
device: str = "cuda:0",
|
|
408
|
+
dtype: torch.dtype = torch.bfloat16,
|
|
379
409
|
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
380
|
-
assign=True,
|
|
410
|
+
assign: bool = True,
|
|
381
411
|
):
|
|
382
|
-
if model_type == "1.3b-t2v":
|
|
383
|
-
config = json.load(open(WAN_DIT_1_3B_T2V_CONFIG_FILE, "r"))
|
|
384
|
-
elif model_type == "14b-t2v":
|
|
385
|
-
config = json.load(open(WAN_DIT_14B_T2V_CONFIG_FILE, "r"))
|
|
386
|
-
elif model_type == "14b-i2v":
|
|
387
|
-
config = json.load(open(WAN_DIT_14B_I2V_CONFIG_FILE, "r"))
|
|
388
|
-
elif model_type == "14b-flf2v":
|
|
389
|
-
config = json.load(open(WAN_DIT_14B_FLF2V_CONFIG_FILE, "r"))
|
|
390
|
-
else:
|
|
391
|
-
raise ValueError(f"Unsupported model type: {model_type}")
|
|
392
412
|
with no_init_weights():
|
|
393
413
|
model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype, attn_kwargs=attn_kwargs)
|
|
394
414
|
model = model.requires_grad_(False)
|