diffsynth-engine 0.3.6.dev13__py3-none-any.whl → 0.3.6.dev14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +2 -3
- diffsynth_engine/conf/models/wan/dit/{14b-i2v.json → wan2.1-flf2v-14b.json} +5 -2
- diffsynth_engine/conf/models/wan/dit/{14b-flf2v.json → wan2.1-i2v-14b.json} +2 -2
- diffsynth_engine/conf/models/wan/dit/{1.3b-t2v.json → wan2.1-t2v-1.3b.json} +0 -1
- diffsynth_engine/conf/models/wan/dit/{14b-t2v.json → wan2.1-t2v-14b.json} +0 -1
- diffsynth_engine/conf/models/wan/dit/wan2.2-i2v-a14b.json +16 -0
- diffsynth_engine/conf/models/wan/dit/wan2.2-t2v-a14b.json +16 -0
- diffsynth_engine/conf/models/wan/dit/wan2.2-ti2v-5b.json +14 -0
- diffsynth_engine/conf/models/wan/vae/wan2.1-vae.json +48 -0
- diffsynth_engine/conf/models/wan/vae/wan2.2-vae.json +112 -0
- diffsynth_engine/configs/pipeline.py +6 -1
- diffsynth_engine/models/wan/wan_dit.py +52 -32
- diffsynth_engine/models/wan/wan_vae.py +355 -60
- diffsynth_engine/pipelines/base.py +15 -11
- diffsynth_engine/pipelines/wan_video.py +175 -74
- diffsynth_engine/utils/constants.py +10 -4
- diffsynth_engine/utils/parallel.py +3 -1
- {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/RECORD +22 -17
- {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/top_level.txt +0 -0
|
@@ -1,12 +1,15 @@
|
|
|
1
|
+
import json
|
|
1
2
|
import torch
|
|
2
3
|
import torch.nn as nn
|
|
3
4
|
import torch.nn.functional as F
|
|
4
5
|
import torch.distributed as dist
|
|
5
6
|
from einops import rearrange, repeat
|
|
6
7
|
from tqdm import tqdm
|
|
8
|
+
from typing import Any, Dict
|
|
7
9
|
|
|
8
10
|
from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
|
|
9
11
|
from diffsynth_engine.models.utils import no_init_weights
|
|
12
|
+
from diffsynth_engine.utils.constants import WAN2_1_VAE_CONFIG_FILE, WAN2_2_VAE_CONFIG_FILE
|
|
10
13
|
|
|
11
14
|
CACHE_T = 2
|
|
12
15
|
|
|
@@ -77,7 +80,7 @@ class Upsample(nn.Upsample):
|
|
|
77
80
|
|
|
78
81
|
|
|
79
82
|
class Resample(nn.Module):
|
|
80
|
-
def __init__(self, dim, mode):
|
|
83
|
+
def __init__(self, dim, mode, keep_channels=False):
|
|
81
84
|
assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d")
|
|
82
85
|
super().__init__()
|
|
83
86
|
self.dim = dim
|
|
@@ -86,11 +89,13 @@ class Resample(nn.Module):
|
|
|
86
89
|
# layers
|
|
87
90
|
if mode == "upsample2d":
|
|
88
91
|
self.resample = nn.Sequential(
|
|
89
|
-
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
|
92
|
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
|
93
|
+
nn.Conv2d(dim, dim if keep_channels else dim // 2, 3, padding=1),
|
|
90
94
|
)
|
|
91
95
|
elif mode == "upsample3d":
|
|
92
96
|
self.resample = nn.Sequential(
|
|
93
|
-
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
|
97
|
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
|
98
|
+
nn.Conv2d(dim, dim if keep_channels else dim // 2, 3, padding=1),
|
|
94
99
|
)
|
|
95
100
|
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
|
96
101
|
|
|
@@ -218,47 +223,259 @@ class AttentionBlock(nn.Module):
|
|
|
218
223
|
return x + identity
|
|
219
224
|
|
|
220
225
|
|
|
226
|
+
def patchify(x, patch_size):
    """Fold ``patch_size`` x ``patch_size`` spatial patches into the channel axis.

    Args:
        x: 4-D (``b c h w``) or 5-D (``b c f h w``) tensor.
        patch_size: side length of the square patch; ``1`` is a no-op.

    Returns:
        ``x`` itself when ``patch_size == 1``; otherwise a tensor whose channel
        axis is multiplied by ``patch_size ** 2`` and whose height and width
        are divided by ``patch_size``.

    Raises:
        ValueError: if ``x`` is neither 4-D nor 5-D (and ``patch_size != 1``).
    """
    if patch_size == 1:
        return x

    rank = x.dim()
    if rank == 4:
        pattern = "b c (h q) (w r) -> b (c r q) h w"
    elif rank == 5:
        pattern = "b c f (h q) (w r) -> b (c r q) f h w"
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return rearrange(x, pattern, q=patch_size, r=patch_size)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def unpatchify(x, patch_size):
    """Unfold channel-packed patches back onto the spatial axes.

    This applies the reverse einops pattern of ``patchify``.

    Args:
        x: 4-D (``b (c r q) h w``) or 5-D (``b (c r q) f h w``) tensor.
        patch_size: side length of the square patch; ``1`` is a no-op.

    Returns:
        ``x`` itself when ``patch_size == 1``; otherwise a tensor whose channel
        axis is divided by ``patch_size ** 2`` and whose height and width are
        multiplied by ``patch_size``.

    Raises:
        ValueError: if ``x`` is neither 4-D nor 5-D (and ``patch_size != 1``).
    """
    if patch_size == 1:
        return x

    if x.dim() == 4:
        x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c r q) f h w -> b c f (h q) (w r)",
            q=patch_size,
            r=patch_size,
        )
    else:
        # Mirror patchify: an unsupported rank is an error, not a silent pass-through.
        raise ValueError(f"Invalid input shape: {x.shape}")
    return x
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
class AvgDown3D(nn.Module):
    """Parameter-free downsampling of a ``(B, C, T, H, W)`` tensor.

    Each ``factor_t x factor_s x factor_s`` neighbourhood is folded into the
    channel axis, then every ``group_size`` consecutive channels are averaged
    so that exactly ``out_channels`` channels remain.
    """

    def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        # The folded channel count must split evenly into out_channels groups.
        assert in_channels * self.factor % out_channels == 0
        self.group_size = in_channels * self.factor // out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the downsampled tensor ``(B, out_channels, T', H', W')``."""
        ft, fs = self.factor_t, self.factor_s

        # Zero-pad at the *front* of the time axis until T is a multiple of ft.
        front_pad = (ft - x.shape[2] % ft) % ft
        x = F.pad(x, (0, 0, 0, 0, front_pad, 0))

        batch, channels, frames, height, width = x.shape
        t_out, h_out, w_out = frames // ft, height // fs, width // fs

        # Split every axis into (coarse, fine) and move the fine parts next to C.
        x = x.view(batch, channels, t_out, ft, h_out, fs, w_out, fs)
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()

        # (C * factor) channels are regrouped as (out_channels, group_size) and
        # each group is averaged away.
        x = x.view(batch, self.out_channels, self.group_size, t_out, h_out, w_out)
        return x.mean(dim=2)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
class DupUp3D(nn.Module):
    """Parameter-free upsampling of a ``(B, C, T, H, W)`` tensor.

    Channels are duplicated ``repeats`` times, then groups of
    ``factor_t * factor_s * factor_s`` channels are unfolded onto the time,
    height and width axes, leaving ``out_channels`` channels.
    """

    def __init__(self, in_channels: int, out_channels: int, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        # Every input channel must be duplicated a whole number of times.
        assert out_channels * self.factor % in_channels == 0
        self.repeats = out_channels * self.factor // in_channels

    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
        """Return the upsampled tensor.

        When ``first_chunk`` is true the first ``factor_t - 1`` output frames
        are dropped.
        """
        ft, fs = self.factor_t, self.factor_s

        x = x.repeat_interleave(self.repeats, dim=1)
        batch, frames, height, width = x.size(0), x.size(2), x.size(3), x.size(4)

        # Channel axis -> (out_channels, ft, fs, fs); interleave each factor
        # with the axis it expands.
        x = x.view(batch, self.out_channels, ft, fs, fs, frames, height, width)
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        x = x.view(batch, self.out_channels, frames * ft, height * fs, width * fs)

        if first_chunk:
            x = x[:, :, ft - 1 :, :, :]
        return x
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
class Down_ResidualBlock(nn.Module):
    """Stack of residual blocks with an optional downsample, plus an averaged shortcut.

    The main path is ``mult`` ``ResidualBlock`` modules, followed by one
    ``Resample`` when ``down_flag`` is set. The shortcut path is an
    ``AvgDown3D`` applied to the block input; the two results are summed.
    """

    def __init__(self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False):
        super().__init__()

        # Shortcut: factor 2 in time and/or space, matching the flags below.
        time_factor = 2 if temperal_downsample else 1
        space_factor = 2 if down_flag else 1
        self.avg_shortcut = AvgDown3D(in_dim, out_dim, factor_t=time_factor, factor_s=space_factor)

        # Main path: only the first residual block changes the channel count.
        stages = [ResidualBlock(in_dim if k == 0 else out_dim, out_dim, dropout) for k in range(mult)]

        # Optional trailing downsample (3d when time is downsampled too).
        if down_flag:
            resample_mode = "downsample3d" if temperal_downsample else "downsample2d"
            stages.append(Resample(out_dim, mode=resample_mode))

        self.downsamples = nn.Sequential(*stages)

    def forward(self, x, feat_cache=None):
        """Run the main path on ``x`` and add the shortcut of the original input."""
        residual_input = x.clone()
        for stage in self.downsamples:
            x = stage(x, feat_cache)

        return x + self.avg_shortcut(residual_input)
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
class Up_ResidualBlock(nn.Module):
    """Stack of residual blocks with an optional upsample, plus a duplicated shortcut.

    The main path is ``mult`` ``ResidualBlock`` modules, followed by one
    channel-preserving ``Resample`` when ``up_flag`` is set. When ``up_flag``
    is set, a ``DupUp3D`` shortcut of the block input is added to the result;
    otherwise the main-path output is returned alone.
    """

    def __init__(self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False):
        super().__init__()

        # Shortcut exists only when this block upsamples; the spatial factor is
        # then always 2 because up_flag is known to be truthy.
        self.avg_shortcut = (
            DupUp3D(
                in_dim,
                out_dim,
                factor_t=2 if temperal_upsample else 1,
                factor_s=2,
            )
            if up_flag
            else None
        )

        # Main path: only the first residual block changes the channel count.
        stages = [ResidualBlock(in_dim if k == 0 else out_dim, out_dim, dropout) for k in range(mult)]

        # Optional trailing upsample (3d when time is upsampled too).
        if up_flag:
            resample_mode = "upsample3d" if temperal_upsample else "upsample2d"
            stages.append(Resample(out_dim, mode=resample_mode, keep_channels=True))

        self.upsamples = nn.Sequential(*stages)

    def forward(self, x, feat_cache=None, first_chunk=False):
        """Run the main path on a copy of ``x``; add the shortcut if one exists."""
        main = x.clone()
        for stage in self.upsamples:
            main = stage(main, feat_cache)

        if self.avg_shortcut is None:
            return main
        return main + self.avg_shortcut(x, first_chunk)
|
|
426
|
+
|
|
427
|
+
|
|
221
428
|
class Encoder3d(nn.Module):
|
|
222
429
|
def __init__(
|
|
223
430
|
self,
|
|
431
|
+
in_channels=3,
|
|
224
432
|
dim=128,
|
|
225
433
|
z_dim=4,
|
|
226
434
|
dim_mult=[1, 2, 4, 4],
|
|
227
435
|
num_res_blocks=2,
|
|
228
|
-
|
|
229
|
-
temperal_downsample=[True, True, False],
|
|
436
|
+
temperal_downsample=[False, True, True],
|
|
230
437
|
dropout=0.0,
|
|
231
438
|
):
|
|
232
439
|
super().__init__()
|
|
440
|
+
self.in_channels = in_channels
|
|
233
441
|
self.dim = dim
|
|
234
442
|
self.z_dim = z_dim
|
|
235
443
|
self.dim_mult = dim_mult
|
|
236
444
|
self.num_res_blocks = num_res_blocks
|
|
237
|
-
self.attn_scales = attn_scales
|
|
238
445
|
self.temperal_downsample = temperal_downsample
|
|
239
446
|
|
|
240
447
|
# dimensions
|
|
241
448
|
dims = [dim * u for u in [1] + dim_mult]
|
|
242
|
-
scale = 1.0
|
|
243
449
|
|
|
244
450
|
# init block
|
|
245
|
-
self.conv1 = CausalConv3d(
|
|
451
|
+
self.conv1 = CausalConv3d(in_channels, dims[0], 3, padding=1)
|
|
246
452
|
|
|
247
453
|
# downsample blocks
|
|
248
454
|
downsamples = []
|
|
455
|
+
use_down_residual_block = in_channels == 12
|
|
249
456
|
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
downsamples.append(
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
457
|
+
if use_down_residual_block:
|
|
458
|
+
t_down_flag = temperal_downsample[i] if i < len(temperal_downsample) else False
|
|
459
|
+
downsamples.append(
|
|
460
|
+
Down_ResidualBlock(
|
|
461
|
+
in_dim=in_dim,
|
|
462
|
+
out_dim=out_dim,
|
|
463
|
+
dropout=dropout,
|
|
464
|
+
mult=num_res_blocks,
|
|
465
|
+
temperal_downsample=t_down_flag,
|
|
466
|
+
down_flag=i != len(dim_mult) - 1,
|
|
467
|
+
)
|
|
468
|
+
)
|
|
469
|
+
else:
|
|
470
|
+
# residual (+attention) blocks
|
|
471
|
+
for _ in range(num_res_blocks):
|
|
472
|
+
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
|
473
|
+
in_dim = out_dim
|
|
474
|
+
|
|
475
|
+
# downsample block
|
|
476
|
+
if i != len(dim_mult) - 1:
|
|
477
|
+
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
|
478
|
+
downsamples.append(Resample(out_dim, mode=mode))
|
|
262
479
|
self.downsamples = nn.Sequential(*downsamples)
|
|
263
480
|
|
|
264
481
|
# middle blocks
|
|
@@ -315,25 +532,24 @@ class Encoder3d(nn.Module):
|
|
|
315
532
|
class Decoder3d(nn.Module):
|
|
316
533
|
def __init__(
|
|
317
534
|
self,
|
|
535
|
+
out_channels=3,
|
|
318
536
|
dim=128,
|
|
319
537
|
z_dim=4,
|
|
320
538
|
dim_mult=[1, 2, 4, 4],
|
|
321
539
|
num_res_blocks=2,
|
|
322
|
-
|
|
323
|
-
temperal_upsample=[False, True, True],
|
|
540
|
+
temperal_upsample=[True, True, False],
|
|
324
541
|
dropout=0.0,
|
|
325
542
|
):
|
|
326
543
|
super().__init__()
|
|
544
|
+
self.out_channels = out_channels
|
|
327
545
|
self.dim = dim
|
|
328
546
|
self.z_dim = z_dim
|
|
329
547
|
self.dim_mult = dim_mult
|
|
330
548
|
self.num_res_blocks = num_res_blocks
|
|
331
|
-
self.attn_scales = attn_scales
|
|
332
549
|
self.temperal_upsample = temperal_upsample
|
|
333
550
|
|
|
334
551
|
# dimensions
|
|
335
552
|
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
|
336
|
-
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
|
337
553
|
|
|
338
554
|
# init block
|
|
339
555
|
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
|
@@ -345,27 +561,40 @@ class Decoder3d(nn.Module):
|
|
|
345
561
|
|
|
346
562
|
# upsample blocks
|
|
347
563
|
upsamples = []
|
|
564
|
+
use_up_residual_block = out_channels == 12
|
|
348
565
|
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
566
|
+
if use_up_residual_block:
|
|
567
|
+
t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
|
|
568
|
+
upsamples.append(
|
|
569
|
+
Up_ResidualBlock(
|
|
570
|
+
in_dim=in_dim,
|
|
571
|
+
out_dim=out_dim,
|
|
572
|
+
dropout=dropout,
|
|
573
|
+
mult=num_res_blocks + 1,
|
|
574
|
+
temperal_upsample=t_up_flag,
|
|
575
|
+
up_flag=i != len(dim_mult) - 1,
|
|
576
|
+
)
|
|
577
|
+
)
|
|
578
|
+
else:
|
|
579
|
+
# residual (+attention) blocks
|
|
580
|
+
if i == 1 or i == 2 or i == 3:
|
|
581
|
+
in_dim = in_dim // 2
|
|
582
|
+
for _ in range(num_res_blocks + 1):
|
|
583
|
+
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
|
584
|
+
in_dim = out_dim
|
|
585
|
+
|
|
586
|
+
# upsample block
|
|
587
|
+
if i != len(dim_mult) - 1:
|
|
588
|
+
mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
|
|
589
|
+
upsamples.append(Resample(out_dim, mode=mode))
|
|
363
590
|
self.upsamples = nn.Sequential(*upsamples)
|
|
364
591
|
|
|
365
592
|
# output blocks
|
|
366
|
-
self.head = nn.Sequential(
|
|
593
|
+
self.head = nn.Sequential(
|
|
594
|
+
RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, out_channels, 3, padding=1)
|
|
595
|
+
)
|
|
367
596
|
|
|
368
|
-
def forward(self, x, feat_cache=None):
|
|
597
|
+
def forward(self, x, feat_cache=None, first_chunk=False):
|
|
369
598
|
## conv1
|
|
370
599
|
if feat_cache is not None:
|
|
371
600
|
key = id(self.conv1)
|
|
@@ -387,7 +616,9 @@ class Decoder3d(nn.Module):
|
|
|
387
616
|
|
|
388
617
|
## upsamples
|
|
389
618
|
for layer in self.upsamples:
|
|
390
|
-
if feat_cache is not None:
|
|
619
|
+
if check_is_instance(layer, Up_ResidualBlock) and feat_cache is not None:
|
|
620
|
+
x = layer(x, feat_cache, first_chunk)
|
|
621
|
+
elif feat_cache is not None:
|
|
391
622
|
x = layer(x, feat_cache)
|
|
392
623
|
else:
|
|
393
624
|
x = layer(x)
|
|
@@ -410,30 +641,36 @@ class Decoder3d(nn.Module):
|
|
|
410
641
|
class VideoVAE(nn.Module):
|
|
411
642
|
def __init__(
    self,
    in_channels=3,
    out_channels=3,
    encoder_dim=96,
    decoder_dim=96,
    z_dim=16,
    dim_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    temperal_downsample=[False, True, True],
    dropout=0.0,
):
    """Build the encoder, decoder and the two 1x1x1 causal convolutions.

    The encoder is created with ``z_dim * 2`` output channels. The decoder's
    temporal-upsample schedule is ``temperal_downsample`` reversed.
    """
    super().__init__()

    # Keep the hyper-parameters on the module for later inspection.
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.encoder_dim = encoder_dim
    self.decoder_dim = decoder_dim
    self.z_dim = z_dim
    self.dim_mult = dim_mult
    self.num_res_blocks = num_res_blocks
    self.temperal_downsample = temperal_downsample
    self.temperal_upsample = temperal_downsample[::-1]

    # modules
    latent_stats_dim = z_dim * 2
    self.encoder = Encoder3d(
        in_channels,
        encoder_dim,
        latent_stats_dim,
        dim_mult,
        num_res_blocks,
        self.temperal_downsample,
        dropout,
    )
    self.conv1 = CausalConv3d(latent_stats_dim, latent_stats_dim, 1)
    self.conv2 = CausalConv3d(z_dim, z_dim, 1)
    self.decoder = Decoder3d(
        out_channels,
        decoder_dim,
        z_dim,
        dim_mult,
        num_res_blocks,
        self.temperal_upsample,
        dropout,
    )
|
|
437
674
|
|
|
438
675
|
def forward(self, x):
|
|
439
676
|
mu, log_var = self.encode(x)
|
|
@@ -443,6 +680,7 @@ class VideoVAE(nn.Module):
|
|
|
443
680
|
|
|
444
681
|
def encode(self, x, scale):
|
|
445
682
|
feat_cache = {}
|
|
683
|
+
x = patchify(x, patch_size=2 if self.in_channels == 12 else 1)
|
|
446
684
|
t = x.shape[2]
|
|
447
685
|
iter_ = 1 + (t - 1) // 4
|
|
448
686
|
|
|
@@ -477,10 +715,11 @@ class VideoVAE(nn.Module):
|
|
|
477
715
|
x = self.conv2(z)
|
|
478
716
|
for i in range(iter_):
|
|
479
717
|
if i == 0:
|
|
480
|
-
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=feat_cache)
|
|
718
|
+
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=feat_cache, first_chunk=True)
|
|
481
719
|
else:
|
|
482
720
|
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=feat_cache)
|
|
483
721
|
out = torch.cat([out, out_], 2) # may add tensor offload
|
|
722
|
+
out = unpatchify(out, patch_size=2 if self.out_channels == 12 else 1)
|
|
484
723
|
return out
|
|
485
724
|
|
|
486
725
|
def reparameterize(self, mu, log_var):
|
|
@@ -515,10 +754,26 @@ class WanVideoVAEStateDictConverter(StateDictConverter):
|
|
|
515
754
|
class WanVideoVAE(PreTrainedModel):
|
|
516
755
|
converter = WanVideoVAEStateDictConverter()
|
|
517
756
|
|
|
518
|
-
def __init__(
|
|
757
|
+
def __init__(
|
|
758
|
+
self,
|
|
759
|
+
in_channels=3,
|
|
760
|
+
out_channels=3,
|
|
761
|
+
encoder_dim=96,
|
|
762
|
+
decoder_dim=96,
|
|
763
|
+
z_dim=16,
|
|
764
|
+
dim_mult=[1, 2, 4, 4],
|
|
765
|
+
num_res_blocks=2,
|
|
766
|
+
temperal_downsample=[False, True, True],
|
|
767
|
+
dropout=0.0,
|
|
768
|
+
patch_size=1,
|
|
769
|
+
mean=None,
|
|
770
|
+
std=None,
|
|
771
|
+
device: str = "cuda:0",
|
|
772
|
+
dtype: torch.dtype = torch.float32,
|
|
773
|
+
):
|
|
519
774
|
super().__init__()
|
|
520
775
|
|
|
521
|
-
mean = [
|
|
776
|
+
mean = mean or [
|
|
522
777
|
-0.7571,
|
|
523
778
|
-0.7089,
|
|
524
779
|
-0.9113,
|
|
@@ -536,7 +791,7 @@ class WanVideoVAE(PreTrainedModel):
|
|
|
536
791
|
0.2503,
|
|
537
792
|
-0.2921,
|
|
538
793
|
]
|
|
539
|
-
std = [
|
|
794
|
+
std = std or [
|
|
540
795
|
2.8184,
|
|
541
796
|
1.4541,
|
|
542
797
|
2.3275,
|
|
@@ -559,13 +814,49 @@ class WanVideoVAE(PreTrainedModel):
|
|
|
559
814
|
self.scale = [self.mean, 1.0 / self.std]
|
|
560
815
|
|
|
561
816
|
# init model
|
|
562
|
-
self.model =
|
|
563
|
-
|
|
817
|
+
self.model = (
|
|
818
|
+
VideoVAE(
|
|
819
|
+
in_channels=in_channels,
|
|
820
|
+
out_channels=out_channels,
|
|
821
|
+
encoder_dim=encoder_dim,
|
|
822
|
+
decoder_dim=decoder_dim,
|
|
823
|
+
z_dim=z_dim,
|
|
824
|
+
dim_mult=dim_mult,
|
|
825
|
+
num_res_blocks=num_res_blocks,
|
|
826
|
+
temperal_downsample=temperal_downsample,
|
|
827
|
+
dropout=dropout,
|
|
828
|
+
)
|
|
829
|
+
.eval()
|
|
830
|
+
.requires_grad_(False)
|
|
831
|
+
)
|
|
832
|
+
self.z_dim = z_dim
|
|
833
|
+
self.patch_size = patch_size
|
|
834
|
+
self.upsampling_factor = 8 * patch_size
|
|
835
|
+
|
|
836
|
+
@staticmethod
def get_model_config(model_type: str) -> dict:
    """Load the bundled JSON configuration for a VAE variant.

    Args:
        model_type: either ``"wan2.1-vae"`` or ``"wan2.2-vae"``.

    Returns:
        The parsed contents of the matching JSON config file.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    MODEL_CONFIG_FILES = {
        "wan2.1-vae": WAN2_1_VAE_CONFIG_FILE,
        "wan2.2-vae": WAN2_2_VAE_CONFIG_FILE,
    }
    if model_type not in MODEL_CONFIG_FILES:
        raise ValueError(f"Unsupported model type: {model_type}")

    config_file = MODEL_CONFIG_FILES[model_type]
    # Pin the encoding so parsing does not depend on the platform locale.
    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)
    return config
|
|
564
849
|
|
|
565
850
|
@classmethod
def from_state_dict(
    cls,
    state_dict: Dict[str, torch.Tensor],
    config: Dict[str, Any],
    device: str = "cuda:0",
    dtype: torch.dtype = torch.float32,
) -> "WanVideoVAE":
    """Instantiate the model from ``config`` and load ``state_dict`` into it.

    The instance is created through ``torch.nn.utils.skip_init`` inside the
    ``no_init_weights`` context, the state dict is loaded with
    ``assign=True``, and the model is then moved to ``device`` / ``dtype``.
    """
    init_kwargs = dict(config, device=device, dtype=dtype)
    with no_init_weights():
        vae = torch.nn.utils.skip_init(cls, **init_kwargs)
    vae.load_state_dict(state_dict, assign=True)
    vae.to(device=device, dtype=dtype, non_blocking=True)
    return vae
|
|
@@ -690,13 +981,13 @@ class WanVideoVAE(PreTrainedModel):
|
|
|
690
981
|
device=data_device,
|
|
691
982
|
)
|
|
692
983
|
values = torch.zeros(
|
|
693
|
-
(1,
|
|
984
|
+
(1, self.z_dim, out_T, H // self.upsampling_factor, W // self.upsampling_factor),
|
|
694
985
|
dtype=video.dtype,
|
|
695
986
|
device=data_device,
|
|
696
987
|
)
|
|
697
988
|
|
|
698
|
-
|
|
699
|
-
for i, (h, h_, w, w_) in enumerate(tqdm(tasks, desc="VAE ENCODING", disable=
|
|
989
|
+
hide_progress = dist.is_initialized() and dist.get_rank() != 0
|
|
990
|
+
for i, (h, h_, w, w_) in enumerate(tqdm(tasks, desc="VAE ENCODING", disable=hide_progress)):
|
|
700
991
|
if dist.is_initialized() and (i % dist.get_world_size() != dist.get_rank()):
|
|
701
992
|
continue
|
|
702
993
|
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
|
|
@@ -727,9 +1018,9 @@ class WanVideoVAE(PreTrainedModel):
|
|
|
727
1018
|
target_h : target_h + hidden_states_batch.shape[3],
|
|
728
1019
|
target_w : target_w + hidden_states_batch.shape[4],
|
|
729
1020
|
] += mask
|
|
730
|
-
if progress_callback is not None and not
|
|
1021
|
+
if progress_callback is not None and not hide_progress:
|
|
731
1022
|
progress_callback(i + 1, len(tasks), "VAE ENCODING")
|
|
732
|
-
if progress_callback is not None and not
|
|
1023
|
+
if progress_callback is not None and not hide_progress:
|
|
733
1024
|
progress_callback(len(tasks), len(tasks), "VAE ENCODING")
|
|
734
1025
|
if dist.is_initialized():
|
|
735
1026
|
dist.all_reduce(values)
|
|
@@ -778,6 +1069,10 @@ class WanVideoVAE(PreTrainedModel):
|
|
|
778
1069
|
for i, hidden_state in enumerate(hidden_states):
|
|
779
1070
|
hidden_state = hidden_state.unsqueeze(0)
|
|
780
1071
|
if tiled:
|
|
1072
|
+
assert tile_size[0] % self.patch_size == 0 and tile_size[1] % self.patch_size == 0
|
|
1073
|
+
assert tile_stride[0] % self.patch_size == 0 and tile_stride[1] % self.patch_size == 0
|
|
1074
|
+
tile_size = (tile_size[0] // self.patch_size, tile_size[1] // self.patch_size)
|
|
1075
|
+
tile_stride = (tile_stride[0] // self.patch_size, tile_stride[1] // self.patch_size)
|
|
781
1076
|
video = self.tiled_decode(
|
|
782
1077
|
hidden_state, device, tile_size, tile_stride, progress_callback=progress_callback
|
|
783
1078
|
)
|
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import torch
|
|
3
3
|
import numpy as np
|
|
4
|
+
from einops import rearrange
|
|
4
5
|
from typing import Dict, List, Tuple
|
|
5
6
|
from PIL import Image
|
|
7
|
+
|
|
6
8
|
from diffsynth_engine.configs import BaseConfig
|
|
7
9
|
from diffsynth_engine.utils.offload import enable_sequential_cpu_offload
|
|
8
10
|
from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
|
|
@@ -38,7 +40,6 @@ class BasePipeline:
|
|
|
38
40
|
self.dtype = dtype
|
|
39
41
|
self.offload_mode = None
|
|
40
42
|
self.model_names = []
|
|
41
|
-
self._models_offload_params = {}
|
|
42
43
|
|
|
43
44
|
@classmethod
|
|
44
45
|
def from_pretrained(cls, model_path_or_config: str | BaseConfig) -> "BasePipeline":
|
|
@@ -140,9 +141,18 @@ class BasePipeline:
|
|
|
140
141
|
return [BasePipeline.preprocess_image(image) for image in images]
|
|
141
142
|
|
|
142
143
|
@staticmethod
def vae_output_to_image(vae_output: torch.Tensor) -> Image.Image | List[Image.Image]:
    """Convert the first item of a VAE output batch into PIL image(s).

    Only ``vae_output[0]`` is used. A 4-D item is rearranged as
    ``c t h w -> t h w c`` and yields one image per ``t`` slice; any other
    item is rearranged as ``c h w -> h w c`` and yields a single image.
    Values are mapped with ``x / 2 + 0.5``, clipped to [0, 1], scaled by 255
    and cast to ``uint8``.
    """
    sample = vae_output[0]
    pattern = "c t h w -> t h w c" if sample.ndim == 4 else "c h w -> h w c"
    sample = rearrange(sample, pattern)

    unit_range = (sample.float() / 2 + 0.5).clip(0, 1)
    pixels = (unit_range * 255).cpu().numpy().astype("uint8")

    if pixels.ndim == 4:
        return [Image.fromarray(frame) for frame in pixels]
    return Image.fromarray(pixels)
|
|
147
157
|
|
|
148
158
|
@staticmethod
|
|
@@ -230,10 +240,6 @@ class BasePipeline:
|
|
|
230
240
|
model = getattr(self, model_name)
|
|
231
241
|
if model is not None:
|
|
232
242
|
model.to("cpu")
|
|
233
|
-
self._models_offload_params[model_name] = {}
|
|
234
|
-
for name, param in model.named_parameters(recurse=True):
|
|
235
|
-
param.data = param.data.pin_memory()
|
|
236
|
-
self._models_offload_params[model_name][name] = param.data
|
|
237
243
|
self.offload_mode = "cpu_offload"
|
|
238
244
|
|
|
239
245
|
def _enable_sequential_cpu_offload(self):
|
|
@@ -272,9 +278,7 @@ class BasePipeline:
|
|
|
272
278
|
and (p := next(model.parameters(), None)) is not None
|
|
273
279
|
and p.device != torch.device("cpu")
|
|
274
280
|
):
|
|
275
|
-
|
|
276
|
-
for name, param in model.named_parameters(recurse=True):
|
|
277
|
-
param.data = param_cache[name]
|
|
281
|
+
model.to("cpu")
|
|
278
282
|
# load the needed models to device
|
|
279
283
|
for model_name in load_model_names:
|
|
280
284
|
model = getattr(self, model_name)
|