diffsynth-engine 0.3.6.dev12__py3-none-any.whl → 0.3.6.dev14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +2 -3
- diffsynth_engine/conf/models/wan/dit/{14b-i2v.json → wan2.1-flf2v-14b.json} +5 -2
- diffsynth_engine/conf/models/wan/dit/{14b-flf2v.json → wan2.1-i2v-14b.json} +2 -2
- diffsynth_engine/conf/models/wan/dit/{1.3b-t2v.json → wan2.1-t2v-1.3b.json} +0 -1
- diffsynth_engine/conf/models/wan/dit/{14b-t2v.json → wan2.1-t2v-14b.json} +0 -1
- diffsynth_engine/conf/models/wan/dit/wan2.2-i2v-a14b.json +16 -0
- diffsynth_engine/conf/models/wan/dit/wan2.2-t2v-a14b.json +16 -0
- diffsynth_engine/conf/models/wan/dit/wan2.2-ti2v-5b.json +14 -0
- diffsynth_engine/conf/models/wan/vae/wan2.1-vae.json +48 -0
- diffsynth_engine/conf/models/wan/vae/wan2.2-vae.json +112 -0
- diffsynth_engine/configs/pipeline.py +6 -1
- diffsynth_engine/models/basic/attention.py +53 -33
- diffsynth_engine/models/wan/wan_dit.py +52 -32
- diffsynth_engine/models/wan/wan_vae.py +355 -60
- diffsynth_engine/pipelines/base.py +15 -11
- diffsynth_engine/pipelines/wan_video.py +175 -74
- diffsynth_engine/utils/constants.py +10 -4
- diffsynth_engine/utils/parallel.py +3 -1
- {diffsynth_engine-0.3.6.dev12.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.3.6.dev12.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/RECORD +23 -18
- {diffsynth_engine-0.3.6.dev12.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.3.6.dev12.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.3.6.dev12.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/top_level.txt +0 -0
|
@@ -9,13 +9,12 @@ class FlowMatchEulerSampler:
|
|
|
9
9
|
self.mask = mask
|
|
10
10
|
|
|
11
11
|
def step(self, latents, model_outputs, i):
|
|
12
|
-
if self.mask is not None:
|
|
13
|
-
model_outputs = model_outputs * self.mask + self.init_latents * (1 - self.mask)
|
|
14
|
-
|
|
15
12
|
dt = self.sigmas[i + 1] - self.sigmas[i]
|
|
16
13
|
latents = latents.to(dtype=torch.float32)
|
|
17
14
|
latents = latents + model_outputs * dt
|
|
18
15
|
latents = latents.to(dtype=model_outputs.dtype)
|
|
16
|
+
if self.mask is not None:
|
|
17
|
+
latents = latents * self.mask + self.init_latents * (1 - self.mask)
|
|
19
18
|
return latents
|
|
20
19
|
|
|
21
20
|
def add_noise(self, latents, noise, sigma):
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
{
|
|
2
|
-
"
|
|
2
|
+
"has_clip_feature": true,
|
|
3
|
+
"has_vae_feature": true,
|
|
4
|
+
"flf_pos_emb": true,
|
|
3
5
|
"patch_size": [1, 2, 2],
|
|
4
6
|
"in_dim": 36,
|
|
5
7
|
"dim": 5120,
|
|
@@ -9,5 +11,6 @@
|
|
|
9
11
|
"out_dim": 16,
|
|
10
12
|
"num_heads": 40,
|
|
11
13
|
"num_layers": 40,
|
|
12
|
-
"eps": 1e-6
|
|
14
|
+
"eps": 1e-6,
|
|
15
|
+
"shift": 16.0
|
|
13
16
|
}
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
{
|
|
2
|
+
"has_vae_feature": true,
|
|
3
|
+
"patch_size": [1, 2, 2],
|
|
4
|
+
"in_dim": 36,
|
|
5
|
+
"dim": 5120,
|
|
6
|
+
"ffn_dim": 13824,
|
|
7
|
+
"freq_dim": 256,
|
|
8
|
+
"text_dim": 4096,
|
|
9
|
+
"out_dim": 16,
|
|
10
|
+
"num_heads": 40,
|
|
11
|
+
"num_layers": 40,
|
|
12
|
+
"eps": 1e-6,
|
|
13
|
+
"boundary": 0.900,
|
|
14
|
+
"cfg_scale": [3.5, 3.5],
|
|
15
|
+
"num_inference_steps": 40
|
|
16
|
+
}
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
{
|
|
2
|
+
"patch_size": [1, 2, 2],
|
|
3
|
+
"in_dim": 16,
|
|
4
|
+
"dim": 5120,
|
|
5
|
+
"ffn_dim": 13824,
|
|
6
|
+
"freq_dim": 256,
|
|
7
|
+
"text_dim": 4096,
|
|
8
|
+
"out_dim": 16,
|
|
9
|
+
"num_heads": 40,
|
|
10
|
+
"num_layers": 40,
|
|
11
|
+
"eps": 1e-6,
|
|
12
|
+
"boundary": 0.875,
|
|
13
|
+
"shift": 12.0,
|
|
14
|
+
"cfg_scale": [3.0, 4.0],
|
|
15
|
+
"num_inference_steps": 40
|
|
16
|
+
}
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
{
|
|
2
|
+
"in_channels": 3,
|
|
3
|
+
"out_channels": 3,
|
|
4
|
+
"encoder_dim": 96,
|
|
5
|
+
"decoder_dim": 96,
|
|
6
|
+
"z_dim": 16,
|
|
7
|
+
"dim_mult": [1, 2, 4, 4],
|
|
8
|
+
"num_res_blocks": 2,
|
|
9
|
+
"temperal_downsample": [false, true, true],
|
|
10
|
+
"dropout": 0.0,
|
|
11
|
+
"patch_size": 1,
|
|
12
|
+
"mean": [
|
|
13
|
+
-0.7571,
|
|
14
|
+
-0.7089,
|
|
15
|
+
-0.9113,
|
|
16
|
+
0.1075,
|
|
17
|
+
-0.1745,
|
|
18
|
+
0.9653,
|
|
19
|
+
-0.1517,
|
|
20
|
+
1.5508,
|
|
21
|
+
0.4134,
|
|
22
|
+
-0.0715,
|
|
23
|
+
0.5517,
|
|
24
|
+
-0.3632,
|
|
25
|
+
-0.1922,
|
|
26
|
+
-0.9497,
|
|
27
|
+
0.2503,
|
|
28
|
+
-0.2921
|
|
29
|
+
],
|
|
30
|
+
"std": [
|
|
31
|
+
2.8184,
|
|
32
|
+
1.4541,
|
|
33
|
+
2.3275,
|
|
34
|
+
2.6558,
|
|
35
|
+
1.2196,
|
|
36
|
+
1.7708,
|
|
37
|
+
2.6052,
|
|
38
|
+
2.0743,
|
|
39
|
+
3.2687,
|
|
40
|
+
2.1526,
|
|
41
|
+
2.8652,
|
|
42
|
+
1.5579,
|
|
43
|
+
1.6382,
|
|
44
|
+
1.1253,
|
|
45
|
+
2.8251,
|
|
46
|
+
1.9160
|
|
47
|
+
]
|
|
48
|
+
}
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
{
|
|
2
|
+
"in_channels": 12,
|
|
3
|
+
"out_channels": 12,
|
|
4
|
+
"encoder_dim": 160,
|
|
5
|
+
"decoder_dim": 256,
|
|
6
|
+
"z_dim": 48,
|
|
7
|
+
"dim_mult": [1, 2, 4, 4],
|
|
8
|
+
"num_res_blocks": 2,
|
|
9
|
+
"temperal_downsample": [false, true, true],
|
|
10
|
+
"dropout": 0.0,
|
|
11
|
+
"patch_size": 2,
|
|
12
|
+
"mean": [
|
|
13
|
+
-0.2289,
|
|
14
|
+
-0.0052,
|
|
15
|
+
-0.1323,
|
|
16
|
+
-0.2339,
|
|
17
|
+
-0.2799,
|
|
18
|
+
0.0174,
|
|
19
|
+
0.1838,
|
|
20
|
+
0.1557,
|
|
21
|
+
-0.1382,
|
|
22
|
+
0.0542,
|
|
23
|
+
0.2813,
|
|
24
|
+
0.0891,
|
|
25
|
+
0.1570,
|
|
26
|
+
-0.0098,
|
|
27
|
+
0.0375,
|
|
28
|
+
-0.1825,
|
|
29
|
+
-0.2246,
|
|
30
|
+
-0.1207,
|
|
31
|
+
-0.0698,
|
|
32
|
+
0.5109,
|
|
33
|
+
0.2665,
|
|
34
|
+
-0.2108,
|
|
35
|
+
-0.2158,
|
|
36
|
+
0.2502,
|
|
37
|
+
-0.2055,
|
|
38
|
+
-0.0322,
|
|
39
|
+
0.1109,
|
|
40
|
+
0.1567,
|
|
41
|
+
-0.0729,
|
|
42
|
+
0.0899,
|
|
43
|
+
-0.2799,
|
|
44
|
+
-0.1230,
|
|
45
|
+
-0.0313,
|
|
46
|
+
-0.1649,
|
|
47
|
+
0.0117,
|
|
48
|
+
0.0723,
|
|
49
|
+
-0.2839,
|
|
50
|
+
-0.2083,
|
|
51
|
+
-0.0520,
|
|
52
|
+
0.3748,
|
|
53
|
+
0.0152,
|
|
54
|
+
0.1957,
|
|
55
|
+
0.1433,
|
|
56
|
+
-0.2944,
|
|
57
|
+
0.3573,
|
|
58
|
+
-0.0548,
|
|
59
|
+
-0.1681,
|
|
60
|
+
-0.0667
|
|
61
|
+
],
|
|
62
|
+
"std": [
|
|
63
|
+
0.4765,
|
|
64
|
+
1.0364,
|
|
65
|
+
0.4514,
|
|
66
|
+
1.1677,
|
|
67
|
+
0.5313,
|
|
68
|
+
0.4990,
|
|
69
|
+
0.4818,
|
|
70
|
+
0.5013,
|
|
71
|
+
0.8158,
|
|
72
|
+
1.0344,
|
|
73
|
+
0.5894,
|
|
74
|
+
1.0901,
|
|
75
|
+
0.6885,
|
|
76
|
+
0.6165,
|
|
77
|
+
0.8454,
|
|
78
|
+
0.4978,
|
|
79
|
+
0.5759,
|
|
80
|
+
0.3523,
|
|
81
|
+
0.7135,
|
|
82
|
+
0.6804,
|
|
83
|
+
0.5833,
|
|
84
|
+
1.4146,
|
|
85
|
+
0.8986,
|
|
86
|
+
0.5659,
|
|
87
|
+
0.7069,
|
|
88
|
+
0.5338,
|
|
89
|
+
0.4889,
|
|
90
|
+
0.4917,
|
|
91
|
+
0.4069,
|
|
92
|
+
0.4999,
|
|
93
|
+
0.6866,
|
|
94
|
+
0.4093,
|
|
95
|
+
0.5709,
|
|
96
|
+
0.6065,
|
|
97
|
+
0.6415,
|
|
98
|
+
0.4944,
|
|
99
|
+
0.5726,
|
|
100
|
+
1.2042,
|
|
101
|
+
0.5458,
|
|
102
|
+
1.6887,
|
|
103
|
+
0.3971,
|
|
104
|
+
1.0600,
|
|
105
|
+
0.3943,
|
|
106
|
+
0.5537,
|
|
107
|
+
0.5444,
|
|
108
|
+
0.4089,
|
|
109
|
+
0.7468,
|
|
110
|
+
0.7744
|
|
111
|
+
]
|
|
112
|
+
}
|
|
@@ -139,7 +139,12 @@ class WanPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, Bas
|
|
|
139
139
|
vae_dtype: torch.dtype = torch.bfloat16
|
|
140
140
|
image_encoder_dtype: torch.dtype = torch.bfloat16
|
|
141
141
|
|
|
142
|
-
|
|
142
|
+
# default params set by model type
|
|
143
|
+
boundary: Optional[float] = field(default=None, init=False) # boundary
|
|
144
|
+
shift: Optional[float] = field(default=None, init=False) # RecifitedFlowScheduler shift factor
|
|
145
|
+
cfg_scale: Optional[float | Tuple[float, float]] = field(default=None, init=False) # default CFG scale
|
|
146
|
+
num_inference_steps: Optional[int] = field(default=None, init=False) # default inference steps
|
|
147
|
+
fps: Optional[int] = field(default=None, init=False) # default FPS
|
|
143
148
|
|
|
144
149
|
# override BaseConfig
|
|
145
150
|
vae_tiled: bool = True
|
|
@@ -14,6 +14,8 @@ from diffsynth_engine.utils.flag import (
|
|
|
14
14
|
SPARGE_ATTN_AVAILABLE,
|
|
15
15
|
)
|
|
16
16
|
|
|
17
|
+
FA3_MAX_HEADDIM = 256
|
|
18
|
+
|
|
17
19
|
logger = logging.get_logger(__name__)
|
|
18
20
|
|
|
19
21
|
|
|
@@ -130,31 +132,40 @@ def attention(
|
|
|
130
132
|
"sage_attn",
|
|
131
133
|
"sparge_attn",
|
|
132
134
|
]
|
|
135
|
+
flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
|
|
133
136
|
if attn_impl is None or attn_impl == "auto":
|
|
134
137
|
if FLASH_ATTN_3_AVAILABLE:
|
|
135
|
-
|
|
136
|
-
|
|
138
|
+
if flash_attn3_compatible:
|
|
139
|
+
return flash_attn3(q, k, v, softmax_scale=scale)
|
|
140
|
+
else:
|
|
141
|
+
logger.warning(
|
|
142
|
+
f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
|
|
143
|
+
)
|
|
144
|
+
if XFORMERS_AVAILABLE:
|
|
137
145
|
return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
|
|
138
|
-
|
|
146
|
+
if SDPA_AVAILABLE:
|
|
139
147
|
return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
|
|
140
|
-
|
|
148
|
+
if FLASH_ATTN_2_AVAILABLE:
|
|
141
149
|
return flash_attn2(q, k, v, softmax_scale=scale)
|
|
142
|
-
|
|
143
|
-
return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
|
|
150
|
+
return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
|
|
144
151
|
else:
|
|
145
152
|
if attn_impl == "eager":
|
|
146
153
|
return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
|
|
147
|
-
|
|
154
|
+
if attn_impl == "flash_attn_3":
|
|
155
|
+
if not flash_attn3_compatible:
|
|
156
|
+
raise RuntimeError(
|
|
157
|
+
f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
|
|
158
|
+
)
|
|
148
159
|
return flash_attn3(q, k, v, softmax_scale=scale)
|
|
149
|
-
|
|
160
|
+
if attn_impl == "flash_attn_2":
|
|
150
161
|
return flash_attn2(q, k, v, softmax_scale=scale)
|
|
151
|
-
|
|
162
|
+
if attn_impl == "xformers":
|
|
152
163
|
return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
|
|
153
|
-
|
|
164
|
+
if attn_impl == "sdpa":
|
|
154
165
|
return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
|
|
155
|
-
|
|
166
|
+
if attn_impl == "sage_attn":
|
|
156
167
|
return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
|
|
157
|
-
|
|
168
|
+
if attn_impl == "sparge_attn":
|
|
158
169
|
return sparge_attn(
|
|
159
170
|
q,
|
|
160
171
|
k,
|
|
@@ -166,8 +177,7 @@ def attention(
|
|
|
166
177
|
cdfthreshd=kwargs.get("sparge_cdfthreshd", 0.98),
|
|
167
178
|
pvthreshd=kwargs.get("sparge_pvthreshd", 50),
|
|
168
179
|
)
|
|
169
|
-
|
|
170
|
-
raise ValueError(f"Invalid attention implementation: {attn_impl}")
|
|
180
|
+
raise ValueError(f"Invalid attention implementation: {attn_impl}")
|
|
171
181
|
|
|
172
182
|
|
|
173
183
|
class Attention(nn.Module):
|
|
@@ -240,32 +250,42 @@ def long_context_attention(
|
|
|
240
250
|
"sage_attn",
|
|
241
251
|
"sparge_attn",
|
|
242
252
|
]
|
|
253
|
+
flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
|
|
243
254
|
if attn_impl is None or attn_impl == "auto":
|
|
244
255
|
if FLASH_ATTN_3_AVAILABLE:
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
256
|
+
if flash_attn3_compatible:
|
|
257
|
+
return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
|
|
258
|
+
else:
|
|
259
|
+
logger.warning(
|
|
260
|
+
f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
|
|
261
|
+
)
|
|
262
|
+
if SDPA_AVAILABLE:
|
|
263
|
+
return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
|
|
264
|
+
if FLASH_ATTN_2_AVAILABLE:
|
|
265
|
+
return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
|
|
266
|
+
raise ValueError("No available long context attention implementation")
|
|
252
267
|
else:
|
|
253
268
|
if attn_impl == "flash_attn_3":
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
269
|
+
if flash_attn3_compatible:
|
|
270
|
+
return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
|
|
271
|
+
else:
|
|
272
|
+
raise RuntimeError(
|
|
273
|
+
f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
|
|
274
|
+
)
|
|
275
|
+
if attn_impl == "flash_attn_2":
|
|
276
|
+
return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
|
|
277
|
+
if attn_impl == "sdpa":
|
|
278
|
+
return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
|
|
279
|
+
if attn_impl == "sage_attn":
|
|
280
|
+
return LongContextAttention(attn_type=AttnType.SAGE_FP8)(q, k, v, softmax_scale=scale)
|
|
281
|
+
if attn_impl == "sparge_attn":
|
|
262
282
|
attn_processor = SparseAttentionMeansim()
|
|
263
283
|
# default args from spas_sage2_attn_meansim_cuda
|
|
264
284
|
attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
|
|
265
285
|
attn_processor.simthreshd1 = torch.tensor(kwargs.get("sparge_simthreshd1", 0.6))
|
|
266
286
|
attn_processor.cdfthreshd = torch.tensor(kwargs.get("sparge_cdfthreshd", 0.98))
|
|
267
287
|
attn_processor.pvthreshd = torch.tensor(kwargs.get("sparge_pvthreshd", 50))
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
288
|
+
return LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)(
|
|
289
|
+
q, k, v, softmax_scale=scale
|
|
290
|
+
)
|
|
291
|
+
raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
|
|
@@ -10,10 +10,13 @@ from diffsynth_engine.models.basic import attention as attention_ops
|
|
|
10
10
|
from diffsynth_engine.models.basic.transformer_helper import RMSNorm
|
|
11
11
|
from diffsynth_engine.models.utils import no_init_weights
|
|
12
12
|
from diffsynth_engine.utils.constants import (
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
13
|
+
WAN2_1_DIT_T2V_1_3B_CONFIG_FILE,
|
|
14
|
+
WAN2_1_DIT_I2V_14B_CONFIG_FILE,
|
|
15
|
+
WAN2_1_DIT_T2V_14B_CONFIG_FILE,
|
|
16
|
+
WAN2_1_DIT_FLF2V_14B_CONFIG_FILE,
|
|
17
|
+
WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
|
|
18
|
+
WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
|
|
19
|
+
WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
|
|
17
20
|
)
|
|
18
21
|
from diffsynth_engine.utils.gguf import gguf_inference
|
|
19
22
|
from diffsynth_engine.utils.parallel import (
|
|
@@ -182,7 +185,9 @@ class DiTBlock(nn.Module):
|
|
|
182
185
|
|
|
183
186
|
def forward(self, x, context, t_mod, freqs):
|
|
184
187
|
# msa: multi-head self-attention mlp: multi-layer perceptron
|
|
185
|
-
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp =
|
|
188
|
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
|
|
189
|
+
t.squeeze(1) for t in (self.modulation + t_mod).chunk(6, dim=1)
|
|
190
|
+
]
|
|
186
191
|
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
|
187
192
|
x = x + gate_msa * self.self_attn(input_x, freqs)
|
|
188
193
|
x = x + self.cross_attn(self.norm3(x), context)
|
|
@@ -237,7 +242,7 @@ class Head(nn.Module):
|
|
|
237
242
|
self.modulation = nn.Parameter(torch.randn(1, 2, dim, device=device, dtype=dtype) / dim**0.5)
|
|
238
243
|
|
|
239
244
|
def forward(self, x, t_mod):
|
|
240
|
-
shift, scale = (self.modulation + t_mod).chunk(2, dim=1)
|
|
245
|
+
shift, scale = [t.squeeze(1) for t in (self.modulation + t_mod.unsqueeze(1)).chunk(2, dim=1)]
|
|
241
246
|
x = self.head(self.norm(x) * (1 + scale) + shift)
|
|
242
247
|
return x
|
|
243
248
|
|
|
@@ -263,17 +268,22 @@ class WanDiT(PreTrainedModel):
|
|
|
263
268
|
patch_size: Tuple[int, int, int],
|
|
264
269
|
num_heads: int,
|
|
265
270
|
num_layers: int,
|
|
266
|
-
|
|
271
|
+
has_clip_feature: bool = False,
|
|
272
|
+
has_vae_feature: bool = False,
|
|
273
|
+
fuse_image_latents: bool = False,
|
|
267
274
|
flf_pos_emb: bool = False,
|
|
268
275
|
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
269
|
-
device: str = "
|
|
276
|
+
device: str = "cuda:0",
|
|
270
277
|
dtype: torch.dtype = torch.bfloat16,
|
|
271
278
|
):
|
|
272
279
|
super().__init__()
|
|
273
280
|
|
|
281
|
+
self.in_dim = in_dim
|
|
274
282
|
self.dim = dim
|
|
275
283
|
self.freq_dim = freq_dim
|
|
276
|
-
self.
|
|
284
|
+
self.has_clip_feature = has_clip_feature
|
|
285
|
+
self.has_vae_feature = has_vae_feature
|
|
286
|
+
self.fuse_image_latents = fuse_image_latents
|
|
277
287
|
self.patch_size = patch_size
|
|
278
288
|
|
|
279
289
|
self.patch_embedding = nn.Conv3d(
|
|
@@ -296,7 +306,7 @@ class WanDiT(PreTrainedModel):
|
|
|
296
306
|
)
|
|
297
307
|
self.blocks = nn.ModuleList(
|
|
298
308
|
[
|
|
299
|
-
DiTBlock(
|
|
309
|
+
DiTBlock(has_clip_feature, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
|
|
300
310
|
for _ in range(num_layers)
|
|
301
311
|
]
|
|
302
312
|
)
|
|
@@ -305,7 +315,7 @@ class WanDiT(PreTrainedModel):
|
|
|
305
315
|
head_dim = dim // num_heads
|
|
306
316
|
self.freqs = precompute_freqs_cis_3d(head_dim)
|
|
307
317
|
|
|
308
|
-
if
|
|
318
|
+
if has_clip_feature:
|
|
309
319
|
self.img_emb = MLP(1280, dim, flf_pos_emb, device=device, dtype=dtype) # clip_feature_dim = 1280
|
|
310
320
|
|
|
311
321
|
def patchify(self, x: torch.Tensor):
|
|
@@ -339,13 +349,14 @@ class WanDiT(PreTrainedModel):
|
|
|
339
349
|
gguf_inference(),
|
|
340
350
|
cfg_parallel((x, context, timestep, clip_feature, y), use_cfg=use_cfg),
|
|
341
351
|
):
|
|
342
|
-
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
|
|
343
|
-
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
|
352
|
+
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) # (s, d)
|
|
353
|
+
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) # (s, 6, d)
|
|
344
354
|
context = self.text_embedding(context)
|
|
345
|
-
if self.
|
|
355
|
+
if self.has_vae_feature:
|
|
346
356
|
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
|
347
|
-
|
|
348
|
-
|
|
357
|
+
if self.has_clip_feature:
|
|
358
|
+
clip_embedding = self.img_emb(clip_feature)
|
|
359
|
+
context = torch.cat([clip_embedding, context], dim=1) # (b, s1 + s2, d)
|
|
349
360
|
x, (f, h, w) = self.patchify(x)
|
|
350
361
|
freqs = (
|
|
351
362
|
torch.cat(
|
|
@@ -360,7 +371,7 @@ class WanDiT(PreTrainedModel):
|
|
|
360
371
|
.to(x.device)
|
|
361
372
|
)
|
|
362
373
|
|
|
363
|
-
with sequence_parallel((x, freqs), seq_dims=(1, 0)):
|
|
374
|
+
with sequence_parallel((x, t, t_mod, freqs), seq_dims=(1, 0, 0, 0)):
|
|
364
375
|
for block in self.blocks:
|
|
365
376
|
x = block(x, context, t_mod, freqs)
|
|
366
377
|
x = self.head(x, t)
|
|
@@ -369,26 +380,35 @@ class WanDiT(PreTrainedModel):
|
|
|
369
380
|
(x,) = cfg_parallel_unshard((x,), use_cfg=use_cfg)
|
|
370
381
|
return x
|
|
371
382
|
|
|
383
|
+
@staticmethod
|
|
384
|
+
def get_model_config(model_type: str):
|
|
385
|
+
MODEL_CONFIG_FILES = {
|
|
386
|
+
"wan2.1-t2v-1.3b": WAN2_1_DIT_T2V_1_3B_CONFIG_FILE,
|
|
387
|
+
"wan2.1-t2v-14b": WAN2_1_DIT_T2V_14B_CONFIG_FILE,
|
|
388
|
+
"wan2.1-i2v-14b": WAN2_1_DIT_I2V_14B_CONFIG_FILE,
|
|
389
|
+
"wan2.1-flf2v-14b": WAN2_1_DIT_FLF2V_14B_CONFIG_FILE,
|
|
390
|
+
"wan2.2-ti2v-5b": WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
|
|
391
|
+
"wan2.2-t2v-a14b": WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
|
|
392
|
+
"wan2.2-i2v-a14b": WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
|
|
393
|
+
}
|
|
394
|
+
if model_type not in MODEL_CONFIG_FILES:
|
|
395
|
+
raise ValueError(f"Unsupported model type: {model_type}")
|
|
396
|
+
|
|
397
|
+
config_file = MODEL_CONFIG_FILES[model_type]
|
|
398
|
+
with open(config_file, "r") as f:
|
|
399
|
+
config = json.load(f)
|
|
400
|
+
return config
|
|
401
|
+
|
|
372
402
|
@classmethod
|
|
373
403
|
def from_state_dict(
|
|
374
404
|
cls,
|
|
375
|
-
state_dict,
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
405
|
+
state_dict: Dict[str, torch.Tensor],
|
|
406
|
+
config: Dict[str, Any],
|
|
407
|
+
device: str = "cuda:0",
|
|
408
|
+
dtype: torch.dtype = torch.bfloat16,
|
|
379
409
|
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
380
|
-
assign=True,
|
|
410
|
+
assign: bool = True,
|
|
381
411
|
):
|
|
382
|
-
if model_type == "1.3b-t2v":
|
|
383
|
-
config = json.load(open(WAN_DIT_1_3B_T2V_CONFIG_FILE, "r"))
|
|
384
|
-
elif model_type == "14b-t2v":
|
|
385
|
-
config = json.load(open(WAN_DIT_14B_T2V_CONFIG_FILE, "r"))
|
|
386
|
-
elif model_type == "14b-i2v":
|
|
387
|
-
config = json.load(open(WAN_DIT_14B_I2V_CONFIG_FILE, "r"))
|
|
388
|
-
elif model_type == "14b-flf2v":
|
|
389
|
-
config = json.load(open(WAN_DIT_14B_FLF2V_CONFIG_FILE, "r"))
|
|
390
|
-
else:
|
|
391
|
-
raise ValueError(f"Unsupported model type: {model_type}")
|
|
392
412
|
with no_init_weights():
|
|
393
413
|
model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype, attn_kwargs=attn_kwargs)
|
|
394
414
|
model = model.requires_grad_(False)
|