diffsynth-engine 0.3.6.dev13__py3-none-any.whl → 0.3.6.dev14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22) hide show
  1. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +2 -3
  2. diffsynth_engine/conf/models/wan/dit/{14b-i2v.json → wan2.1-flf2v-14b.json} +5 -2
  3. diffsynth_engine/conf/models/wan/dit/{14b-flf2v.json → wan2.1-i2v-14b.json} +2 -2
  4. diffsynth_engine/conf/models/wan/dit/{1.3b-t2v.json → wan2.1-t2v-1.3b.json} +0 -1
  5. diffsynth_engine/conf/models/wan/dit/{14b-t2v.json → wan2.1-t2v-14b.json} +0 -1
  6. diffsynth_engine/conf/models/wan/dit/wan2.2-i2v-a14b.json +16 -0
  7. diffsynth_engine/conf/models/wan/dit/wan2.2-t2v-a14b.json +16 -0
  8. diffsynth_engine/conf/models/wan/dit/wan2.2-ti2v-5b.json +14 -0
  9. diffsynth_engine/conf/models/wan/vae/wan2.1-vae.json +48 -0
  10. diffsynth_engine/conf/models/wan/vae/wan2.2-vae.json +112 -0
  11. diffsynth_engine/configs/pipeline.py +6 -1
  12. diffsynth_engine/models/wan/wan_dit.py +52 -32
  13. diffsynth_engine/models/wan/wan_vae.py +355 -60
  14. diffsynth_engine/pipelines/base.py +15 -11
  15. diffsynth_engine/pipelines/wan_video.py +175 -74
  16. diffsynth_engine/utils/constants.py +10 -4
  17. diffsynth_engine/utils/parallel.py +3 -1
  18. {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/METADATA +1 -1
  19. {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/RECORD +22 -17
  20. {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/WHEEL +0 -0
  21. {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/licenses/LICENSE +0 -0
  22. {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/top_level.txt +0 -0
@@ -9,13 +9,12 @@ class FlowMatchEulerSampler:
9
9
  self.mask = mask
10
10
 
11
11
  def step(self, latents, model_outputs, i):
12
- if self.mask is not None:
13
- model_outputs = model_outputs * self.mask + self.init_latents * (1 - self.mask)
14
-
15
12
  dt = self.sigmas[i + 1] - self.sigmas[i]
16
13
  latents = latents.to(dtype=torch.float32)
17
14
  latents = latents + model_outputs * dt
18
15
  latents = latents.to(dtype=model_outputs.dtype)
16
+ if self.mask is not None:
17
+ latents = latents * self.mask + self.init_latents * (1 - self.mask)
19
18
  return latents
20
19
 
21
20
  def add_noise(self, latents, noise, sigma):
@@ -1,5 +1,7 @@
1
1
  {
2
- "has_image_input": true,
2
+ "has_clip_feature": true,
3
+ "has_vae_feature": true,
4
+ "flf_pos_emb": true,
3
5
  "patch_size": [1, 2, 2],
4
6
  "in_dim": 36,
5
7
  "dim": 5120,
@@ -9,5 +11,6 @@
9
11
  "out_dim": 16,
10
12
  "num_heads": 40,
11
13
  "num_layers": 40,
12
- "eps": 1e-6
14
+ "eps": 1e-6,
15
+ "shift": 16.0
13
16
  }
@@ -1,6 +1,6 @@
1
1
  {
2
- "has_image_input": true,
3
- "flf_pos_emb": true,
2
+ "has_clip_feature": true,
3
+ "has_vae_feature": true,
4
4
  "patch_size": [1, 2, 2],
5
5
  "in_dim": 36,
6
6
  "dim": 5120,
@@ -1,5 +1,4 @@
1
1
  {
2
- "has_image_input": false,
3
2
  "patch_size": [1, 2, 2],
4
3
  "in_dim": 16,
5
4
  "dim": 1536,
@@ -1,5 +1,4 @@
1
1
  {
2
- "has_image_input": false,
3
2
  "patch_size": [1, 2, 2],
4
3
  "in_dim": 16,
5
4
  "dim": 5120,
@@ -0,0 +1,16 @@
1
+ {
2
+ "has_vae_feature": true,
3
+ "patch_size": [1, 2, 2],
4
+ "in_dim": 36,
5
+ "dim": 5120,
6
+ "ffn_dim": 13824,
7
+ "freq_dim": 256,
8
+ "text_dim": 4096,
9
+ "out_dim": 16,
10
+ "num_heads": 40,
11
+ "num_layers": 40,
12
+ "eps": 1e-6,
13
+ "boundary": 0.900,
14
+ "cfg_scale": [3.5, 3.5],
15
+ "num_inference_steps": 40
16
+ }
@@ -0,0 +1,16 @@
1
+ {
2
+ "patch_size": [1, 2, 2],
3
+ "in_dim": 16,
4
+ "dim": 5120,
5
+ "ffn_dim": 13824,
6
+ "freq_dim": 256,
7
+ "text_dim": 4096,
8
+ "out_dim": 16,
9
+ "num_heads": 40,
10
+ "num_layers": 40,
11
+ "eps": 1e-6,
12
+ "boundary": 0.875,
13
+ "shift": 12.0,
14
+ "cfg_scale": [3.0, 4.0],
15
+ "num_inference_steps": 40
16
+ }
@@ -0,0 +1,14 @@
1
+ {
2
+ "fuse_image_latents": true,
3
+ "patch_size": [1, 2, 2],
4
+ "in_dim": 48,
5
+ "dim": 3072,
6
+ "ffn_dim": 14336,
7
+ "freq_dim": 256,
8
+ "text_dim": 4096,
9
+ "out_dim": 48,
10
+ "num_heads": 24,
11
+ "num_layers": 30,
12
+ "eps": 1e-6,
13
+ "fps": 24
14
+ }
@@ -0,0 +1,48 @@
1
+ {
2
+ "in_channels": 3,
3
+ "out_channels": 3,
4
+ "encoder_dim": 96,
5
+ "decoder_dim": 96,
6
+ "z_dim": 16,
7
+ "dim_mult": [1, 2, 4, 4],
8
+ "num_res_blocks": 2,
9
+ "temperal_downsample": [false, true, true],
10
+ "dropout": 0.0,
11
+ "patch_size": 1,
12
+ "mean": [
13
+ -0.7571,
14
+ -0.7089,
15
+ -0.9113,
16
+ 0.1075,
17
+ -0.1745,
18
+ 0.9653,
19
+ -0.1517,
20
+ 1.5508,
21
+ 0.4134,
22
+ -0.0715,
23
+ 0.5517,
24
+ -0.3632,
25
+ -0.1922,
26
+ -0.9497,
27
+ 0.2503,
28
+ -0.2921
29
+ ],
30
+ "std": [
31
+ 2.8184,
32
+ 1.4541,
33
+ 2.3275,
34
+ 2.6558,
35
+ 1.2196,
36
+ 1.7708,
37
+ 2.6052,
38
+ 2.0743,
39
+ 3.2687,
40
+ 2.1526,
41
+ 2.8652,
42
+ 1.5579,
43
+ 1.6382,
44
+ 1.1253,
45
+ 2.8251,
46
+ 1.9160
47
+ ]
48
+ }
@@ -0,0 +1,112 @@
1
+ {
2
+ "in_channels": 12,
3
+ "out_channels": 12,
4
+ "encoder_dim": 160,
5
+ "decoder_dim": 256,
6
+ "z_dim": 48,
7
+ "dim_mult": [1, 2, 4, 4],
8
+ "num_res_blocks": 2,
9
+ "temperal_downsample": [false, true, true],
10
+ "dropout": 0.0,
11
+ "patch_size": 2,
12
+ "mean": [
13
+ -0.2289,
14
+ -0.0052,
15
+ -0.1323,
16
+ -0.2339,
17
+ -0.2799,
18
+ 0.0174,
19
+ 0.1838,
20
+ 0.1557,
21
+ -0.1382,
22
+ 0.0542,
23
+ 0.2813,
24
+ 0.0891,
25
+ 0.1570,
26
+ -0.0098,
27
+ 0.0375,
28
+ -0.1825,
29
+ -0.2246,
30
+ -0.1207,
31
+ -0.0698,
32
+ 0.5109,
33
+ 0.2665,
34
+ -0.2108,
35
+ -0.2158,
36
+ 0.2502,
37
+ -0.2055,
38
+ -0.0322,
39
+ 0.1109,
40
+ 0.1567,
41
+ -0.0729,
42
+ 0.0899,
43
+ -0.2799,
44
+ -0.1230,
45
+ -0.0313,
46
+ -0.1649,
47
+ 0.0117,
48
+ 0.0723,
49
+ -0.2839,
50
+ -0.2083,
51
+ -0.0520,
52
+ 0.3748,
53
+ 0.0152,
54
+ 0.1957,
55
+ 0.1433,
56
+ -0.2944,
57
+ 0.3573,
58
+ -0.0548,
59
+ -0.1681,
60
+ -0.0667
61
+ ],
62
+ "std": [
63
+ 0.4765,
64
+ 1.0364,
65
+ 0.4514,
66
+ 1.1677,
67
+ 0.5313,
68
+ 0.4990,
69
+ 0.4818,
70
+ 0.5013,
71
+ 0.8158,
72
+ 1.0344,
73
+ 0.5894,
74
+ 1.0901,
75
+ 0.6885,
76
+ 0.6165,
77
+ 0.8454,
78
+ 0.4978,
79
+ 0.5759,
80
+ 0.3523,
81
+ 0.7135,
82
+ 0.6804,
83
+ 0.5833,
84
+ 1.4146,
85
+ 0.8986,
86
+ 0.5659,
87
+ 0.7069,
88
+ 0.5338,
89
+ 0.4889,
90
+ 0.4917,
91
+ 0.4069,
92
+ 0.4999,
93
+ 0.6866,
94
+ 0.4093,
95
+ 0.5709,
96
+ 0.6065,
97
+ 0.6415,
98
+ 0.4944,
99
+ 0.5726,
100
+ 1.2042,
101
+ 0.5458,
102
+ 1.6887,
103
+ 0.3971,
104
+ 1.0600,
105
+ 0.3943,
106
+ 0.5537,
107
+ 0.5444,
108
+ 0.4089,
109
+ 0.7468,
110
+ 0.7744
111
+ ]
112
+ }
@@ -139,7 +139,12 @@ class WanPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, Bas
139
139
  vae_dtype: torch.dtype = torch.bfloat16
140
140
  image_encoder_dtype: torch.dtype = torch.bfloat16
141
141
 
142
- shift: Optional[float] = field(default=None, init=False) # RecifitedFlowScheduler shift factor, set by model type
142
+ # default params set by model type
143
+ boundary: Optional[float] = field(default=None, init=False) # boundary
144
+ shift: Optional[float] = field(default=None, init=False) # RecifitedFlowScheduler shift factor
145
+ cfg_scale: Optional[float | Tuple[float, float]] = field(default=None, init=False) # default CFG scale
146
+ num_inference_steps: Optional[int] = field(default=None, init=False) # default inference steps
147
+ fps: Optional[int] = field(default=None, init=False) # default FPS
143
148
 
144
149
  # override BaseConfig
145
150
  vae_tiled: bool = True
@@ -10,10 +10,13 @@ from diffsynth_engine.models.basic import attention as attention_ops
10
10
  from diffsynth_engine.models.basic.transformer_helper import RMSNorm
11
11
  from diffsynth_engine.models.utils import no_init_weights
12
12
  from diffsynth_engine.utils.constants import (
13
- WAN_DIT_1_3B_T2V_CONFIG_FILE,
14
- WAN_DIT_14B_I2V_CONFIG_FILE,
15
- WAN_DIT_14B_T2V_CONFIG_FILE,
16
- WAN_DIT_14B_FLF2V_CONFIG_FILE,
13
+ WAN2_1_DIT_T2V_1_3B_CONFIG_FILE,
14
+ WAN2_1_DIT_I2V_14B_CONFIG_FILE,
15
+ WAN2_1_DIT_T2V_14B_CONFIG_FILE,
16
+ WAN2_1_DIT_FLF2V_14B_CONFIG_FILE,
17
+ WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
18
+ WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
19
+ WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
17
20
  )
18
21
  from diffsynth_engine.utils.gguf import gguf_inference
19
22
  from diffsynth_engine.utils.parallel import (
@@ -182,7 +185,9 @@ class DiTBlock(nn.Module):
182
185
 
183
186
  def forward(self, x, context, t_mod, freqs):
184
187
  # msa: multi-head self-attention mlp: multi-layer perceptron
185
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + t_mod).chunk(6, dim=1)
188
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
189
+ t.squeeze(1) for t in (self.modulation + t_mod).chunk(6, dim=1)
190
+ ]
186
191
  input_x = modulate(self.norm1(x), shift_msa, scale_msa)
187
192
  x = x + gate_msa * self.self_attn(input_x, freqs)
188
193
  x = x + self.cross_attn(self.norm3(x), context)
@@ -237,7 +242,7 @@ class Head(nn.Module):
237
242
  self.modulation = nn.Parameter(torch.randn(1, 2, dim, device=device, dtype=dtype) / dim**0.5)
238
243
 
239
244
  def forward(self, x, t_mod):
240
- shift, scale = (self.modulation + t_mod).chunk(2, dim=1)
245
+ shift, scale = [t.squeeze(1) for t in (self.modulation + t_mod.unsqueeze(1)).chunk(2, dim=1)]
241
246
  x = self.head(self.norm(x) * (1 + scale) + shift)
242
247
  return x
243
248
 
@@ -263,17 +268,22 @@ class WanDiT(PreTrainedModel):
263
268
  patch_size: Tuple[int, int, int],
264
269
  num_heads: int,
265
270
  num_layers: int,
266
- has_image_input: bool,
271
+ has_clip_feature: bool = False,
272
+ has_vae_feature: bool = False,
273
+ fuse_image_latents: bool = False,
267
274
  flf_pos_emb: bool = False,
268
275
  attn_kwargs: Optional[Dict[str, Any]] = None,
269
- device: str = "cpu",
276
+ device: str = "cuda:0",
270
277
  dtype: torch.dtype = torch.bfloat16,
271
278
  ):
272
279
  super().__init__()
273
280
 
281
+ self.in_dim = in_dim
274
282
  self.dim = dim
275
283
  self.freq_dim = freq_dim
276
- self.has_image_input = has_image_input
284
+ self.has_clip_feature = has_clip_feature
285
+ self.has_vae_feature = has_vae_feature
286
+ self.fuse_image_latents = fuse_image_latents
277
287
  self.patch_size = patch_size
278
288
 
279
289
  self.patch_embedding = nn.Conv3d(
@@ -296,7 +306,7 @@ class WanDiT(PreTrainedModel):
296
306
  )
297
307
  self.blocks = nn.ModuleList(
298
308
  [
299
- DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
309
+ DiTBlock(has_clip_feature, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
300
310
  for _ in range(num_layers)
301
311
  ]
302
312
  )
@@ -305,7 +315,7 @@ class WanDiT(PreTrainedModel):
305
315
  head_dim = dim // num_heads
306
316
  self.freqs = precompute_freqs_cis_3d(head_dim)
307
317
 
308
- if has_image_input:
318
+ if has_clip_feature:
309
319
  self.img_emb = MLP(1280, dim, flf_pos_emb, device=device, dtype=dtype) # clip_feature_dim = 1280
310
320
 
311
321
  def patchify(self, x: torch.Tensor):
@@ -339,13 +349,14 @@ class WanDiT(PreTrainedModel):
339
349
  gguf_inference(),
340
350
  cfg_parallel((x, context, timestep, clip_feature, y), use_cfg=use_cfg),
341
351
  ):
342
- t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
343
- t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
352
+ t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) # (s, d)
353
+ t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) # (s, 6, d)
344
354
  context = self.text_embedding(context)
345
- if self.has_image_input:
355
+ if self.has_vae_feature:
346
356
  x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
347
- clip_embdding = self.img_emb(clip_feature)
348
- context = torch.cat([clip_embdding, context], dim=1) # (b, s1 + s2, d)
357
+ if self.has_clip_feature:
358
+ clip_embedding = self.img_emb(clip_feature)
359
+ context = torch.cat([clip_embedding, context], dim=1) # (b, s1 + s2, d)
349
360
  x, (f, h, w) = self.patchify(x)
350
361
  freqs = (
351
362
  torch.cat(
@@ -360,7 +371,7 @@ class WanDiT(PreTrainedModel):
360
371
  .to(x.device)
361
372
  )
362
373
 
363
- with sequence_parallel((x, freqs), seq_dims=(1, 0)):
374
+ with sequence_parallel((x, t, t_mod, freqs), seq_dims=(1, 0, 0, 0)):
364
375
  for block in self.blocks:
365
376
  x = block(x, context, t_mod, freqs)
366
377
  x = self.head(x, t)
@@ -369,26 +380,35 @@ class WanDiT(PreTrainedModel):
369
380
  (x,) = cfg_parallel_unshard((x,), use_cfg=use_cfg)
370
381
  return x
371
382
 
383
+ @staticmethod
384
+ def get_model_config(model_type: str):
385
+ MODEL_CONFIG_FILES = {
386
+ "wan2.1-t2v-1.3b": WAN2_1_DIT_T2V_1_3B_CONFIG_FILE,
387
+ "wan2.1-t2v-14b": WAN2_1_DIT_T2V_14B_CONFIG_FILE,
388
+ "wan2.1-i2v-14b": WAN2_1_DIT_I2V_14B_CONFIG_FILE,
389
+ "wan2.1-flf2v-14b": WAN2_1_DIT_FLF2V_14B_CONFIG_FILE,
390
+ "wan2.2-ti2v-5b": WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
391
+ "wan2.2-t2v-a14b": WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
392
+ "wan2.2-i2v-a14b": WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
393
+ }
394
+ if model_type not in MODEL_CONFIG_FILES:
395
+ raise ValueError(f"Unsupported model type: {model_type}")
396
+
397
+ config_file = MODEL_CONFIG_FILES[model_type]
398
+ with open(config_file, "r") as f:
399
+ config = json.load(f)
400
+ return config
401
+
372
402
  @classmethod
373
403
  def from_state_dict(
374
404
  cls,
375
- state_dict,
376
- device,
377
- dtype,
378
- model_type="1.3b-t2v",
405
+ state_dict: Dict[str, torch.Tensor],
406
+ config: Dict[str, Any],
407
+ device: str = "cuda:0",
408
+ dtype: torch.dtype = torch.bfloat16,
379
409
  attn_kwargs: Optional[Dict[str, Any]] = None,
380
- assign=True,
410
+ assign: bool = True,
381
411
  ):
382
- if model_type == "1.3b-t2v":
383
- config = json.load(open(WAN_DIT_1_3B_T2V_CONFIG_FILE, "r"))
384
- elif model_type == "14b-t2v":
385
- config = json.load(open(WAN_DIT_14B_T2V_CONFIG_FILE, "r"))
386
- elif model_type == "14b-i2v":
387
- config = json.load(open(WAN_DIT_14B_I2V_CONFIG_FILE, "r"))
388
- elif model_type == "14b-flf2v":
389
- config = json.load(open(WAN_DIT_14B_FLF2V_CONFIG_FILE, "r"))
390
- else:
391
- raise ValueError(f"Unsupported model type: {model_type}")
392
412
  with no_init_weights():
393
413
  model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype, attn_kwargs=attn_kwargs)
394
414
  model = model.requires_grad_(False)