diffsynth-engine 0.3.6.dev13__py3-none-any.whl → 0.3.6.dev14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22):
  1. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +2 -3
  2. diffsynth_engine/conf/models/wan/dit/{14b-i2v.json → wan2.1-flf2v-14b.json} +5 -2
  3. diffsynth_engine/conf/models/wan/dit/{14b-flf2v.json → wan2.1-i2v-14b.json} +2 -2
  4. diffsynth_engine/conf/models/wan/dit/{1.3b-t2v.json → wan2.1-t2v-1.3b.json} +0 -1
  5. diffsynth_engine/conf/models/wan/dit/{14b-t2v.json → wan2.1-t2v-14b.json} +0 -1
  6. diffsynth_engine/conf/models/wan/dit/wan2.2-i2v-a14b.json +16 -0
  7. diffsynth_engine/conf/models/wan/dit/wan2.2-t2v-a14b.json +16 -0
  8. diffsynth_engine/conf/models/wan/dit/wan2.2-ti2v-5b.json +14 -0
  9. diffsynth_engine/conf/models/wan/vae/wan2.1-vae.json +48 -0
  10. diffsynth_engine/conf/models/wan/vae/wan2.2-vae.json +112 -0
  11. diffsynth_engine/configs/pipeline.py +6 -1
  12. diffsynth_engine/models/wan/wan_dit.py +52 -32
  13. diffsynth_engine/models/wan/wan_vae.py +355 -60
  14. diffsynth_engine/pipelines/base.py +15 -11
  15. diffsynth_engine/pipelines/wan_video.py +175 -74
  16. diffsynth_engine/utils/constants.py +10 -4
  17. diffsynth_engine/utils/parallel.py +3 -1
  18. {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/METADATA +1 -1
  19. {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/RECORD +22 -17
  20. {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/WHEEL +0 -0
  21. {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/licenses/LICENSE +0 -0
  22. {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,15 @@
1
+ import json
1
2
  import torch
2
3
  import torch.nn as nn
3
4
  import torch.nn.functional as F
4
5
  import torch.distributed as dist
5
6
  from einops import rearrange, repeat
6
7
  from tqdm import tqdm
8
+ from typing import Any, Dict
7
9
 
8
10
  from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
9
11
  from diffsynth_engine.models.utils import no_init_weights
12
+ from diffsynth_engine.utils.constants import WAN2_1_VAE_CONFIG_FILE, WAN2_2_VAE_CONFIG_FILE
10
13
 
11
14
  CACHE_T = 2
12
15
 
@@ -77,7 +80,7 @@ class Upsample(nn.Upsample):
77
80
 
78
81
 
79
82
  class Resample(nn.Module):
80
- def __init__(self, dim, mode):
83
+ def __init__(self, dim, mode, keep_channels=False):
81
84
  assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d")
82
85
  super().__init__()
83
86
  self.dim = dim
@@ -86,11 +89,13 @@ class Resample(nn.Module):
86
89
  # layers
87
90
  if mode == "upsample2d":
88
91
  self.resample = nn.Sequential(
89
- Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
92
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
93
+ nn.Conv2d(dim, dim if keep_channels else dim // 2, 3, padding=1),
90
94
  )
91
95
  elif mode == "upsample3d":
92
96
  self.resample = nn.Sequential(
93
- Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
97
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
98
+ nn.Conv2d(dim, dim if keep_channels else dim // 2, 3, padding=1),
94
99
  )
95
100
  self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
96
101
 
@@ -218,47 +223,259 @@ class AttentionBlock(nn.Module):
218
223
  return x + identity
219
224
 
220
225
 
226
def patchify(x, patch_size):
    """Fold every ``patch_size`` x ``patch_size`` spatial patch into the channel axis.

    Accepts a 4-D ``(b, c, h, w)`` or 5-D ``(b, c, f, h, w)`` tensor. The result
    has ``c * patch_size**2`` channels, and height and width are divided by
    ``patch_size``. The packed channel axis is ordered ``(c, r, q)``: ``q`` indexes
    rows inside a patch and ``r`` indexes columns. ``patch_size == 1`` is a no-op
    and returns ``x`` itself.

    Raises:
        ValueError: if ``x`` is neither 4-D nor 5-D (only checked when
            ``patch_size != 1``).
    """
    if patch_size == 1:
        return x

    # One einops pattern per supported rank; the frame axis ``f`` passes through untouched.
    patterns_by_rank = {
        4: "b c (h q) (w r) -> b (c r q) h w",
        5: "b c f (h q) (w r) -> b (c r q) f h w",
    }
    pattern = patterns_by_rank.get(x.dim())
    if pattern is None:
        raise ValueError(f"Invalid input shape: {x.shape}")
    return rearrange(x, pattern, q=patch_size, r=patch_size)
242
+
243
+
244
def unpatchify(x, patch_size):
    """Inverse of ``patchify``: unfold channel-packed patches back into space.

    Accepts a 4-D ``(b, c * p * p, h, w)`` or 5-D ``(b, c * p * p, f, h, w)``
    tensor, where ``p == patch_size``. The packed channel axis is read in
    ``(c, r, q)`` order, and the result has ``c`` channels with height and width
    multiplied by ``patch_size``. ``patch_size == 1`` is a no-op and returns
    ``x`` itself.

    Raises:
        ValueError: if ``x`` is neither 4-D nor 5-D (only checked when
            ``patch_size != 1``). The previous version returned such inputs
            unchanged, which hid shape bugs and disagreed with ``patchify``.
    """
    if patch_size == 1:
        return x

    if x.dim() == 4:
        return rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
    if x.dim() == 5:
        return rearrange(x, "b (c r q) f h w -> b c f (h q) (w r)", q=patch_size, r=patch_size)
    # Mirror patchify: an unsupported rank is a caller error, not a pass-through.
    raise ValueError(f"Invalid input shape: {x.shape}")
258
+
259
+
260
class AvgDown3D(nn.Module):
    """Parameter-free 3-D downsampler: fold space/time into channels, then average groups.

    Input is a 5-D ``(B, C, T, H, W)`` tensor. Each ``factor_t x factor_s x factor_s``
    neighbourhood is moved into the channel axis, giving ``C * factor`` channels.
    Those channels are split into ``out_channels`` consecutive groups of
    ``group_size`` and averaged. The output has shape
    ``(B, out_channels, T' / factor_t, H / factor_s, W / factor_s)``, where ``T'`` is
    ``T`` after padding.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        factor_t,
        factor_s=1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = factor_t * factor_s * factor_s

        # The folded channel count must divide evenly into out_channels groups.
        assert in_channels * self.factor % out_channels == 0
        self.group_size = in_channels * self.factor // out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ft, fs = self.factor_t, self.factor_s

        # Zero-pad the *front* of the time axis so its length is a multiple of ft.
        # F.pad pairs run from the last dim backwards: (W_l, W_r, H_l, H_r, T_l, T_r).
        front_pad = (ft - x.shape[2] % ft) % ft
        x = F.pad(x, (0, 0, 0, 0, front_pad, 0))

        batch, chans, frames, height, width = x.shape
        t_out, h_out, w_out = frames // ft, height // fs, width // fs

        # Split every axis into (coarse, fine), then place the fine factors right after channels.
        blocks = x.view(batch, chans, t_out, ft, h_out, fs, w_out, fs)
        blocks = blocks.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()

        # (chans, ft, fs, fs) is flattened into out_channels groups of group_size; average each.
        grouped = blocks.view(batch, self.out_channels, self.group_size, t_out, h_out, w_out)
        return grouped.mean(dim=2)
311
+
312
+
313
class DupUp3D(nn.Module):
    """Parameter-free 3-D upsampler: duplicate channels, then unfold them into space/time.

    Each input channel is repeated ``repeats`` times along the channel axis. The
    resulting ``out_channels * factor`` channels are then redistributed into a
    ``factor_t x factor_s x factor_s`` neighbourhood. The output has shape
    ``(B, out_channels, T * factor_t, H * factor_s, W * factor_s)``. When
    ``first_chunk`` is set, the first ``factor_t - 1`` output frames are dropped.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        factor_t,
        factor_s=1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = factor_t * factor_s * factor_s

        # Duplicated channel count must cover out_channels * factor exactly.
        assert out_channels * self.factor % in_channels == 0
        self.repeats = out_channels * self.factor // in_channels

    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
        ft, fs = self.factor_t, self.factor_s
        dup = x.repeat_interleave(self.repeats, dim=1)

        batch = dup.size(0)
        frames, height, width = dup.size(2), dup.size(3), dup.size(4)

        # Channels -> (out_channels, ft, fs, fs), then interleave each fine factor after its axis.
        spread = dup.view(batch, self.out_channels, ft, fs, fs, frames, height, width)
        spread = spread.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        out = spread.view(batch, self.out_channels, frames * ft, height * fs, width * fs)

        if first_chunk:
            # Keep only the last frame of the first temporal group.
            out = out[:, :, ft - 1 :, :, :]
        return out
355
+
356
+
357
class Down_ResidualBlock(nn.Module):
    """Encoder stage: a residual-block stack plus optional downsample, with an averaging shortcut.

    The main path is ``mult`` ``ResidualBlock``s. The first maps ``in_dim -> out_dim``
    and the rest keep ``out_dim``. When ``down_flag`` is set, a ``Resample`` follows
    them: "downsample3d" if ``temperal_downsample`` else "downsample2d". The
    shortcut is an ``AvgDown3D`` that downsamples by the same spatial/temporal
    factors, and its output is added to the main path.
    """

    def __init__(self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False):
        super().__init__()

        # Parameter-free shortcut matching the main path's resolution change.
        self.avg_shortcut = AvgDown3D(
            in_dim,
            out_dim,
            factor_t=2 if temperal_downsample else 1,
            factor_s=2 if down_flag else 1,
        )

        # Main path: only the first residual block changes the channel count.
        stages = [ResidualBlock(in_dim if k == 0 else out_dim, out_dim, dropout) for k in range(mult)]
        if down_flag:
            stages.append(Resample(out_dim, mode="downsample3d" if temperal_downsample else "downsample2d"))

        self.downsamples = nn.Sequential(*stages)

    def forward(self, x, feat_cache=None):
        # Snapshot the input before the main path runs; the shortcut is computed from this copy.
        residual_input = x.clone()
        for stage in self.downsamples:
            x = stage(x, feat_cache)

        return x + self.avg_shortcut(residual_input)
388
+
389
+
390
class Up_ResidualBlock(nn.Module):
    """Decoder stage: a residual-block stack plus optional upsample, with a duplication shortcut.

    The main path is ``mult`` ``ResidualBlock``s. The first maps ``in_dim -> out_dim``
    and the rest keep ``out_dim``. When ``up_flag`` is set, a channel-preserving
    ``Resample`` follows them: "upsample3d" if ``temperal_upsample`` else
    "upsample2d". Only in that case is a ``DupUp3D`` shortcut, with spatial factor 2,
    added to the main path's output. Otherwise the main path is returned as-is.
    """

    def __init__(self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False):
        super().__init__()
        # The shortcut exists only when this stage upsamples, so its spatial factor is always 2.
        if not up_flag:
            self.avg_shortcut = None
        else:
            self.avg_shortcut = DupUp3D(
                in_dim,
                out_dim,
                factor_t=2 if temperal_upsample else 1,
                factor_s=2,
            )

        # Main path: only the first residual block changes the channel count.
        stages = [ResidualBlock(in_dim if k == 0 else out_dim, out_dim, dropout) for k in range(mult)]
        if up_flag:
            # keep_channels=True: the upsampling conv keeps out_dim instead of halving it.
            stages.append(
                Resample(out_dim, mode="upsample3d" if temperal_upsample else "upsample2d", keep_channels=True)
            )

        self.upsamples = nn.Sequential(*stages)

    def forward(self, x, feat_cache=None, first_chunk=False):
        # The main path runs on a copy so the untouched input can feed the shortcut.
        main = x.clone()
        for stage in self.upsamples:
            main = stage(main, feat_cache)

        if self.avg_shortcut is None:
            return main
        return main + self.avg_shortcut(x, first_chunk)
426
+
427
+
221
428
  class Encoder3d(nn.Module):
222
429
  def __init__(
223
430
  self,
431
+ in_channels=3,
224
432
  dim=128,
225
433
  z_dim=4,
226
434
  dim_mult=[1, 2, 4, 4],
227
435
  num_res_blocks=2,
228
- attn_scales=[],
229
- temperal_downsample=[True, True, False],
436
+ temperal_downsample=[False, True, True],
230
437
  dropout=0.0,
231
438
  ):
232
439
  super().__init__()
440
+ self.in_channels = in_channels
233
441
  self.dim = dim
234
442
  self.z_dim = z_dim
235
443
  self.dim_mult = dim_mult
236
444
  self.num_res_blocks = num_res_blocks
237
- self.attn_scales = attn_scales
238
445
  self.temperal_downsample = temperal_downsample
239
446
 
240
447
  # dimensions
241
448
  dims = [dim * u for u in [1] + dim_mult]
242
- scale = 1.0
243
449
 
244
450
  # init block
245
- self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
451
+ self.conv1 = CausalConv3d(in_channels, dims[0], 3, padding=1)
246
452
 
247
453
  # downsample blocks
248
454
  downsamples = []
455
+ use_down_residual_block = in_channels == 12
249
456
  for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
250
- # residual (+attention) blocks
251
- for _ in range(num_res_blocks):
252
- downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
253
- if scale in attn_scales:
254
- downsamples.append(AttentionBlock(out_dim))
255
- in_dim = out_dim
256
-
257
- # downsample block
258
- if i != len(dim_mult) - 1:
259
- mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
260
- downsamples.append(Resample(out_dim, mode=mode))
261
- scale /= 2.0
457
+ if use_down_residual_block:
458
+ t_down_flag = temperal_downsample[i] if i < len(temperal_downsample) else False
459
+ downsamples.append(
460
+ Down_ResidualBlock(
461
+ in_dim=in_dim,
462
+ out_dim=out_dim,
463
+ dropout=dropout,
464
+ mult=num_res_blocks,
465
+ temperal_downsample=t_down_flag,
466
+ down_flag=i != len(dim_mult) - 1,
467
+ )
468
+ )
469
+ else:
470
+ # residual (+attention) blocks
471
+ for _ in range(num_res_blocks):
472
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
473
+ in_dim = out_dim
474
+
475
+ # downsample block
476
+ if i != len(dim_mult) - 1:
477
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
478
+ downsamples.append(Resample(out_dim, mode=mode))
262
479
  self.downsamples = nn.Sequential(*downsamples)
263
480
 
264
481
  # middle blocks
@@ -315,25 +532,24 @@ class Encoder3d(nn.Module):
315
532
  class Decoder3d(nn.Module):
316
533
  def __init__(
317
534
  self,
535
+ out_channels=3,
318
536
  dim=128,
319
537
  z_dim=4,
320
538
  dim_mult=[1, 2, 4, 4],
321
539
  num_res_blocks=2,
322
- attn_scales=[],
323
- temperal_upsample=[False, True, True],
540
+ temperal_upsample=[True, True, False],
324
541
  dropout=0.0,
325
542
  ):
326
543
  super().__init__()
544
+ self.out_channels = out_channels
327
545
  self.dim = dim
328
546
  self.z_dim = z_dim
329
547
  self.dim_mult = dim_mult
330
548
  self.num_res_blocks = num_res_blocks
331
- self.attn_scales = attn_scales
332
549
  self.temperal_upsample = temperal_upsample
333
550
 
334
551
  # dimensions
335
552
  dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
336
- scale = 1.0 / 2 ** (len(dim_mult) - 2)
337
553
 
338
554
  # init block
339
555
  self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
@@ -345,27 +561,40 @@ class Decoder3d(nn.Module):
345
561
 
346
562
  # upsample blocks
347
563
  upsamples = []
564
+ use_up_residual_block = out_channels == 12
348
565
  for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
349
- # residual (+attention) blocks
350
- if i == 1 or i == 2 or i == 3:
351
- in_dim = in_dim // 2
352
- for _ in range(num_res_blocks + 1):
353
- upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
354
- if scale in attn_scales:
355
- upsamples.append(AttentionBlock(out_dim))
356
- in_dim = out_dim
357
-
358
- # upsample block
359
- if i != len(dim_mult) - 1:
360
- mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
361
- upsamples.append(Resample(out_dim, mode=mode))
362
- scale *= 2.0
566
+ if use_up_residual_block:
567
+ t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
568
+ upsamples.append(
569
+ Up_ResidualBlock(
570
+ in_dim=in_dim,
571
+ out_dim=out_dim,
572
+ dropout=dropout,
573
+ mult=num_res_blocks + 1,
574
+ temperal_upsample=t_up_flag,
575
+ up_flag=i != len(dim_mult) - 1,
576
+ )
577
+ )
578
+ else:
579
+ # residual (+attention) blocks
580
+ if i == 1 or i == 2 or i == 3:
581
+ in_dim = in_dim // 2
582
+ for _ in range(num_res_blocks + 1):
583
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
584
+ in_dim = out_dim
585
+
586
+ # upsample block
587
+ if i != len(dim_mult) - 1:
588
+ mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
589
+ upsamples.append(Resample(out_dim, mode=mode))
363
590
  self.upsamples = nn.Sequential(*upsamples)
364
591
 
365
592
  # output blocks
366
- self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1))
593
+ self.head = nn.Sequential(
594
+ RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, out_channels, 3, padding=1)
595
+ )
367
596
 
368
- def forward(self, x, feat_cache=None):
597
+ def forward(self, x, feat_cache=None, first_chunk=False):
369
598
  ## conv1
370
599
  if feat_cache is not None:
371
600
  key = id(self.conv1)
@@ -387,7 +616,9 @@ class Decoder3d(nn.Module):
387
616
 
388
617
  ## upsamples
389
618
  for layer in self.upsamples:
390
- if feat_cache is not None:
619
+ if check_is_instance(layer, Up_ResidualBlock) and feat_cache is not None:
620
+ x = layer(x, feat_cache, first_chunk)
621
+ elif feat_cache is not None:
391
622
  x = layer(x, feat_cache)
392
623
  else:
393
624
  x = layer(x)
@@ -410,30 +641,36 @@ class Decoder3d(nn.Module):
410
641
  class VideoVAE(nn.Module):
411
642
  def __init__(
412
643
  self,
413
- dim=96,
644
+ in_channels=3,
645
+ out_channels=3,
646
+ encoder_dim=96,
647
+ decoder_dim=96,
414
648
  z_dim=16,
415
649
  dim_mult=[1, 2, 4, 4],
416
650
  num_res_blocks=2,
417
- attn_scales=[],
418
651
  temperal_downsample=[False, True, True],
419
652
  dropout=0.0,
420
653
  ):
421
654
  super().__init__()
422
- self.dim = dim
655
+ self.in_channels = in_channels
656
+ self.out_channels = out_channels
657
+ self.encoder_dim = encoder_dim
658
+ self.decoder_dim = decoder_dim
423
659
  self.z_dim = z_dim
424
660
  self.dim_mult = dim_mult
425
661
  self.num_res_blocks = num_res_blocks
426
- self.attn_scales = attn_scales
427
662
  self.temperal_downsample = temperal_downsample
428
663
  self.temperal_upsample = temperal_downsample[::-1]
429
664
 
430
665
  # modules
431
666
  self.encoder = Encoder3d(
432
- dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
667
+ in_channels, encoder_dim, z_dim * 2, dim_mult, num_res_blocks, self.temperal_downsample, dropout
433
668
  )
434
669
  self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
435
670
  self.conv2 = CausalConv3d(z_dim, z_dim, 1)
436
- self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)
671
+ self.decoder = Decoder3d(
672
+ out_channels, decoder_dim, z_dim, dim_mult, num_res_blocks, self.temperal_upsample, dropout
673
+ )
437
674
 
438
675
  def forward(self, x):
439
676
  mu, log_var = self.encode(x)
@@ -443,6 +680,7 @@ class VideoVAE(nn.Module):
443
680
 
444
681
  def encode(self, x, scale):
445
682
  feat_cache = {}
683
+ x = patchify(x, patch_size=2 if self.in_channels == 12 else 1)
446
684
  t = x.shape[2]
447
685
  iter_ = 1 + (t - 1) // 4
448
686
 
@@ -477,10 +715,11 @@ class VideoVAE(nn.Module):
477
715
  x = self.conv2(z)
478
716
  for i in range(iter_):
479
717
  if i == 0:
480
- out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=feat_cache)
718
+ out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=feat_cache, first_chunk=True)
481
719
  else:
482
720
  out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=feat_cache)
483
721
  out = torch.cat([out, out_], 2) # may add tensor offload
722
+ out = unpatchify(out, patch_size=2 if self.out_channels == 12 else 1)
484
723
  return out
485
724
 
486
725
  def reparameterize(self, mu, log_var):
@@ -515,10 +754,26 @@ class WanVideoVAEStateDictConverter(StateDictConverter):
515
754
  class WanVideoVAE(PreTrainedModel):
516
755
  converter = WanVideoVAEStateDictConverter()
517
756
 
518
- def __init__(self, z_dim=16, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
757
+ def __init__(
758
+ self,
759
+ in_channels=3,
760
+ out_channels=3,
761
+ encoder_dim=96,
762
+ decoder_dim=96,
763
+ z_dim=16,
764
+ dim_mult=[1, 2, 4, 4],
765
+ num_res_blocks=2,
766
+ temperal_downsample=[False, True, True],
767
+ dropout=0.0,
768
+ patch_size=1,
769
+ mean=None,
770
+ std=None,
771
+ device: str = "cuda:0",
772
+ dtype: torch.dtype = torch.float32,
773
+ ):
519
774
  super().__init__()
520
775
 
521
- mean = [
776
+ mean = mean or [
522
777
  -0.7571,
523
778
  -0.7089,
524
779
  -0.9113,
@@ -536,7 +791,7 @@ class WanVideoVAE(PreTrainedModel):
536
791
  0.2503,
537
792
  -0.2921,
538
793
  ]
539
- std = [
794
+ std = std or [
540
795
  2.8184,
541
796
  1.4541,
542
797
  2.3275,
@@ -559,13 +814,49 @@ class WanVideoVAE(PreTrainedModel):
559
814
  self.scale = [self.mean, 1.0 / self.std]
560
815
 
561
816
  # init model
562
- self.model = VideoVAE(z_dim=z_dim).eval().requires_grad_(False)
563
- self.upsampling_factor = 8
817
+ self.model = (
818
+ VideoVAE(
819
+ in_channels=in_channels,
820
+ out_channels=out_channels,
821
+ encoder_dim=encoder_dim,
822
+ decoder_dim=decoder_dim,
823
+ z_dim=z_dim,
824
+ dim_mult=dim_mult,
825
+ num_res_blocks=num_res_blocks,
826
+ temperal_downsample=temperal_downsample,
827
+ dropout=dropout,
828
+ )
829
+ .eval()
830
+ .requires_grad_(False)
831
+ )
832
+ self.z_dim = z_dim
833
+ self.patch_size = patch_size
834
+ self.upsampling_factor = 8 * patch_size
835
+
836
@staticmethod
def get_model_config(model_type: str) -> dict:
    """Load the JSON architecture config for a named Wan VAE variant.

    Args:
        model_type: ``"wan2.1-vae"`` or ``"wan2.2-vae"``.

    Returns:
        The object parsed from the matching JSON config file. Presumably this is
        passed as ``config`` to ``from_state_dict``, which expands it into
        constructor keyword arguments (verify against callers).

    Raises:
        ValueError: if ``model_type`` is not one of the known variants.
    """
    config_files = {
        "wan2.1-vae": WAN2_1_VAE_CONFIG_FILE,
        "wan2.2-vae": WAN2_2_VAE_CONFIG_FILE,
    }
    if model_type not in config_files:
        raise ValueError(f"Unsupported model type: {model_type}")

    # JSON text is UTF-8 by specification; do not depend on the platform's default encoding.
    with open(config_files[model_type], "r", encoding="utf-8") as f:
        return json.load(f)
564
849
 
565
850
  @classmethod
566
- def from_state_dict(cls, state_dict, device="cuda:0", dtype=torch.float32) -> "WanVideoVAE":
851
+ def from_state_dict(
852
+ cls,
853
+ state_dict: Dict[str, torch.Tensor],
854
+ config: Dict[str, Any],
855
+ device: str = "cuda:0",
856
+ dtype: torch.dtype = torch.float32,
857
+ ) -> "WanVideoVAE":
567
858
  with no_init_weights():
568
- model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
859
+ model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype)
569
860
  model.load_state_dict(state_dict, assign=True)
570
861
  model.to(device=device, dtype=dtype, non_blocking=True)
571
862
  return model
@@ -690,13 +981,13 @@ class WanVideoVAE(PreTrainedModel):
690
981
  device=data_device,
691
982
  )
692
983
  values = torch.zeros(
693
- (1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor),
984
+ (1, self.z_dim, out_T, H // self.upsampling_factor, W // self.upsampling_factor),
694
985
  dtype=video.dtype,
695
986
  device=data_device,
696
987
  )
697
988
 
698
- hide_progress_bar = dist.is_initialized() and dist.get_rank() != 0
699
- for i, (h, h_, w, w_) in enumerate(tqdm(tasks, desc="VAE ENCODING", disable=hide_progress_bar)):
989
+ hide_progress = dist.is_initialized() and dist.get_rank() != 0
990
+ for i, (h, h_, w, w_) in enumerate(tqdm(tasks, desc="VAE ENCODING", disable=hide_progress)):
700
991
  if dist.is_initialized() and (i % dist.get_world_size() != dist.get_rank()):
701
992
  continue
702
993
  hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
@@ -727,9 +1018,9 @@ class WanVideoVAE(PreTrainedModel):
727
1018
  target_h : target_h + hidden_states_batch.shape[3],
728
1019
  target_w : target_w + hidden_states_batch.shape[4],
729
1020
  ] += mask
730
- if progress_callback is not None and not hide_progress_bar:
1021
+ if progress_callback is not None and not hide_progress:
731
1022
  progress_callback(i + 1, len(tasks), "VAE ENCODING")
732
- if progress_callback is not None and not hide_progress_bar:
1023
+ if progress_callback is not None and not hide_progress:
733
1024
  progress_callback(len(tasks), len(tasks), "VAE ENCODING")
734
1025
  if dist.is_initialized():
735
1026
  dist.all_reduce(values)
@@ -778,6 +1069,10 @@ class WanVideoVAE(PreTrainedModel):
778
1069
  for i, hidden_state in enumerate(hidden_states):
779
1070
  hidden_state = hidden_state.unsqueeze(0)
780
1071
  if tiled:
1072
+ assert tile_size[0] % self.patch_size == 0 and tile_size[1] % self.patch_size == 0
1073
+ assert tile_stride[0] % self.patch_size == 0 and tile_stride[1] % self.patch_size == 0
1074
+ tile_size = (tile_size[0] // self.patch_size, tile_size[1] // self.patch_size)
1075
+ tile_stride = (tile_stride[0] // self.patch_size, tile_stride[1] // self.patch_size)
781
1076
  video = self.tiled_decode(
782
1077
  hidden_state, device, tile_size, tile_stride, progress_callback=progress_callback
783
1078
  )
@@ -1,8 +1,10 @@
1
1
  import os
2
2
  import torch
3
3
  import numpy as np
4
+ from einops import rearrange
4
5
  from typing import Dict, List, Tuple
5
6
  from PIL import Image
7
+
6
8
  from diffsynth_engine.configs import BaseConfig
7
9
  from diffsynth_engine.utils.offload import enable_sequential_cpu_offload
8
10
  from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
@@ -38,7 +40,6 @@ class BasePipeline:
38
40
  self.dtype = dtype
39
41
  self.offload_mode = None
40
42
  self.model_names = []
41
- self._models_offload_params = {}
42
43
 
43
44
  @classmethod
44
45
  def from_pretrained(cls, model_path_or_config: str | BaseConfig) -> "BasePipeline":
@@ -140,9 +141,18 @@ class BasePipeline:
140
141
  return [BasePipeline.preprocess_image(image) for image in images]
141
142
 
142
143
  @staticmethod
143
- def vae_output_to_image(vae_output: torch.Tensor) -> Image.Image:
144
- image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
145
- image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
144
+ def vae_output_to_image(vae_output: torch.Tensor) -> Image.Image | List[Image.Image]:
145
+ vae_output = vae_output[0]
146
+ if vae_output.ndim == 4:
147
+ vae_output = rearrange(vae_output, "c t h w -> t h w c")
148
+ else:
149
+ vae_output = rearrange(vae_output, "c h w -> h w c")
150
+
151
+ image = ((vae_output.float() / 2 + 0.5).clip(0, 1) * 255).cpu().numpy().astype("uint8")
152
+ if image.ndim == 4:
153
+ image = [Image.fromarray(img) for img in image]
154
+ else:
155
+ image = Image.fromarray(image)
146
156
  return image
147
157
 
148
158
  @staticmethod
@@ -230,10 +240,6 @@ class BasePipeline:
230
240
  model = getattr(self, model_name)
231
241
  if model is not None:
232
242
  model.to("cpu")
233
- self._models_offload_params[model_name] = {}
234
- for name, param in model.named_parameters(recurse=True):
235
- param.data = param.data.pin_memory()
236
- self._models_offload_params[model_name][name] = param.data
237
243
  self.offload_mode = "cpu_offload"
238
244
 
239
245
  def _enable_sequential_cpu_offload(self):
@@ -272,9 +278,7 @@ class BasePipeline:
272
278
  and (p := next(model.parameters(), None)) is not None
273
279
  and p.device != torch.device("cpu")
274
280
  ):
275
- param_cache = self._models_offload_params[model_name]
276
- for name, param in model.named_parameters(recurse=True):
277
- param.data = param_cache[name]
281
+ model.to("cpu")
278
282
  # load the needed models to device
279
283
  for model_name in load_model_names:
280
284
  model = getattr(self, model_name)