diffsynth-engine 0.3.6.dev13__py3-none-any.whl → 0.3.6.dev14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +2 -3
  2. diffsynth_engine/conf/models/wan/dit/{14b-i2v.json → wan2.1-flf2v-14b.json} +5 -2
  3. diffsynth_engine/conf/models/wan/dit/{14b-flf2v.json → wan2.1-i2v-14b.json} +2 -2
  4. diffsynth_engine/conf/models/wan/dit/{1.3b-t2v.json → wan2.1-t2v-1.3b.json} +0 -1
  5. diffsynth_engine/conf/models/wan/dit/{14b-t2v.json → wan2.1-t2v-14b.json} +0 -1
  6. diffsynth_engine/conf/models/wan/dit/wan2.2-i2v-a14b.json +16 -0
  7. diffsynth_engine/conf/models/wan/dit/wan2.2-t2v-a14b.json +16 -0
  8. diffsynth_engine/conf/models/wan/dit/wan2.2-ti2v-5b.json +14 -0
  9. diffsynth_engine/conf/models/wan/vae/wan2.1-vae.json +48 -0
  10. diffsynth_engine/conf/models/wan/vae/wan2.2-vae.json +112 -0
  11. diffsynth_engine/configs/pipeline.py +6 -1
  12. diffsynth_engine/models/wan/wan_dit.py +52 -32
  13. diffsynth_engine/models/wan/wan_vae.py +355 -60
  14. diffsynth_engine/pipelines/base.py +15 -11
  15. diffsynth_engine/pipelines/wan_video.py +175 -74
  16. diffsynth_engine/utils/constants.py +10 -4
  17. diffsynth_engine/utils/parallel.py +3 -1
  18. {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/METADATA +1 -1
  19. {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/RECORD +22 -17
  20. {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/WHEEL +0 -0
  21. {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/licenses/LICENSE +0 -0
  22. {diffsynth_engine-0.3.6.dev13.dist-info → diffsynth_engine-0.3.6.dev14.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,5 @@
1
1
  import torch
2
2
  import torch.distributed as dist
3
- import numpy as np
4
- from einops import rearrange
5
3
  from typing import Callable, List, Tuple, Optional
6
4
  from tqdm import tqdm
7
5
  from PIL import Image
@@ -97,14 +95,6 @@ class WanLoRAConverter(LoRAStateDictConverter):
97
95
  return state_dict
98
96
 
99
97
 
100
- SHIFT_FACTORS = {
101
- "1.3b-t2v": 5.0,
102
- "14b-t2v": 5.0,
103
- "14b-i2v": 5.0,
104
- "14b-flf2v": 16.0,
105
- }
106
-
107
-
108
98
  class WanVideoPipeline(BasePipeline):
109
99
  lora_converter = WanLoRAConverter()
110
100
 
@@ -114,6 +104,7 @@ class WanVideoPipeline(BasePipeline):
114
104
  tokenizer: WanT5Tokenizer,
115
105
  text_encoder: WanTextEncoder,
116
106
  dit: WanDiT,
107
+ dit2: WanDiT | None,
117
108
  vae: WanVideoVAE,
118
109
  image_encoder: WanImageEncoder,
119
110
  ):
@@ -125,6 +116,7 @@ class WanVideoPipeline(BasePipeline):
125
116
  dtype=config.model_dtype,
126
117
  )
127
118
  self.config = config
119
+ self.upsampling_factor = vae.upsampling_factor
128
120
  # sampler
129
121
  self.noise_scheduler = RecifitedFlowScheduler(
130
122
  shift=config.shift if config.shift is not None else 5.0,
@@ -135,10 +127,11 @@ class WanVideoPipeline(BasePipeline):
135
127
  # models
136
128
  self.tokenizer = tokenizer
137
129
  self.text_encoder = text_encoder
138
- self.dit = dit
130
+ self.dit = dit # high noise model
131
+ self.dit2 = dit2 # low noise model
139
132
  self.vae = vae
140
133
  self.image_encoder = image_encoder
141
- self.model_names = ["text_encoder", "dit", "vae", "image_encoder"]
134
+ self.model_names = ["text_encoder", "dit", "dit2", "vae", "image_encoder"]
142
135
 
143
136
  def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
144
137
  assert self.config.tp_degree is None or self.config.tp_degree == 1, (
@@ -166,40 +159,62 @@ class WanVideoPipeline(BasePipeline):
166
159
  prompt_emb = prompt_emb.masked_fill(mask.unsqueeze(-1).expand_as(prompt_emb) == 0, 0)
167
160
  return prompt_emb
168
161
 
169
- def encode_image(self, images: Image.Image | List[Image.Image], num_frames, height, width):
162
+ def encode_clip_feature(self, images: Image.Image | List[Image.Image], height, width):
163
+ if not images or not self.dit.has_clip_feature:
164
+ return None
165
+
166
+ self.load_models_to_device(["image_encoder"])
170
167
  if isinstance(images, Image.Image):
171
168
  images = [images]
172
- images = [
173
- self.preprocess_image(image.resize((width, height), Image.Resampling.LANCZOS)).to(
174
- device=self.device, dtype=self.config.image_encoder_dtype
175
- )
176
- for image in images
177
- ]
169
+ images = [self.preprocess_image(img.resize((width, height), Image.Resampling.LANCZOS)) for img in images]
170
+ images = [img.to(device=self.device, dtype=self.config.image_encoder_dtype) for img in images]
178
171
  clip_context = self.image_encoder.encode_image(images).to(self.dtype)
172
+ return clip_context
173
+
174
+ def encode_vae_feature(self, images: Image.Image | List[Image.Image], num_frames, height, width):
175
+ if not images or not self.dit.has_vae_feature:
176
+ return None
179
177
 
178
+ self.load_models_to_device(["vae"])
179
+ if isinstance(images, Image.Image):
180
+ images = [images]
181
+ images = [self.preprocess_image(img.resize((width, height), Image.Resampling.LANCZOS)) for img in images]
180
182
  indices = torch.linspace(0, num_frames - 1, len(images), dtype=torch.long)
181
- msk = torch.zeros(1, num_frames, height // 8, width // 8, device=self.device, dtype=self.config.vae_dtype)
183
+ msk = torch.zeros(
184
+ 1,
185
+ num_frames,
186
+ height // self.upsampling_factor,
187
+ width // self.upsampling_factor,
188
+ device=self.device,
189
+ dtype=self.config.vae_dtype,
190
+ )
182
191
  msk[:, indices] = 1
183
192
  msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
184
- msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
193
+ msk = msk.view(1, msk.shape[1] // 4, 4, height // self.upsampling_factor, width // self.upsampling_factor)
185
194
  msk = msk.transpose(1, 2).squeeze(0)
186
195
 
187
196
  video = torch.zeros(3, num_frames, height, width).to(device=self.device, dtype=self.config.vae_dtype)
188
- video[:, indices] = torch.concat([image.transpose(0, 1) for image in images], dim=1).to(
189
- dtype=self.config.vae_dtype
197
+ video[:, indices] = torch.concat([img.transpose(0, 1) for img in images], dim=1).to(
198
+ device=self.device, dtype=self.config.vae_dtype
190
199
  )
191
200
  y = self.vae.encode([video], device=self.device)[0]
192
201
  y = torch.concat([msk, y]).to(dtype=self.dtype)
193
- return clip_context, y.unsqueeze(0)
202
+ return y.unsqueeze(0)
194
203
 
195
- def tensor2video(self, frames):
196
- frames = rearrange(frames, "C T H W -> T H W C")
197
- frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
198
- frames = [Image.fromarray(frame) for frame in frames]
199
- return frames
204
+ def encode_image_latents(self, images: Image.Image | List[Image.Image], height, width):
205
+ if not images or not self.dit.fuse_image_latents:
206
+ return
200
207
 
201
- def encode_video(self, videos: torch.Tensor):
202
- videos = videos.to(dtype=self.config.vae_dtype, device=self.device)
208
+ self.load_models_to_device(["vae"])
209
+ if isinstance(images, Image.Image):
210
+ images = [images]
211
+ frames = [self.preprocess_image(img.resize((width, height), Image.Resampling.LANCZOS)) for img in images]
212
+ video = torch.stack(frames, dim=2).squeeze(0)
213
+ latents = self.encode_video([video]).to(dtype=self.dtype, device=self.device)
214
+ return latents
215
+
216
+ def encode_video(self, videos: List[torch.Tensor]) -> torch.Tensor:
217
+ videos = [video.to(dtype=self.config.vae_dtype, device=self.device) for video in videos]
203
218
  latents = self.vae.encode(
204
219
  videos,
205
220
  device=self.device,
@@ -210,7 +225,7 @@ class WanVideoPipeline(BasePipeline):
210
225
  latents = latents.to(dtype=self.config.model_dtype, device=self.device)
211
226
  return latents
212
227
 
213
- def decode_video(self, latents, progress_callback=None) -> List[torch.Tensor]:
228
+ def decode_video(self, latents: torch.Tensor, progress_callback=None) -> List[torch.Tensor]:
214
229
  latents = latents.to(dtype=self.config.vae_dtype, device=self.device)
215
230
  videos = self.vae.decode(
216
231
  latents,
@@ -225,6 +240,7 @@ class WanVideoPipeline(BasePipeline):
225
240
 
226
241
  def predict_noise_with_cfg(
227
242
  self,
243
+ model: WanDiT,
228
244
  latents: torch.Tensor,
229
245
  image_clip_feature: torch.Tensor,
230
246
  image_y: torch.Tensor,
@@ -236,6 +252,7 @@ class WanVideoPipeline(BasePipeline):
236
252
  ):
237
253
  if cfg_scale <= 1.0:
238
254
  return self.predict_noise(
255
+ model=model,
239
256
  latents=latents,
240
257
  image_clip_feature=image_clip_feature,
241
258
  image_y=image_y,
@@ -245,6 +262,7 @@ class WanVideoPipeline(BasePipeline):
245
262
  if not batch_cfg:
246
263
  # cfg by predict noise one by one
247
264
  positive_noise_pred = self.predict_noise(
265
+ model=model,
248
266
  latents=latents,
249
267
  image_clip_feature=image_clip_feature,
250
268
  image_y=image_y,
@@ -252,6 +270,7 @@ class WanVideoPipeline(BasePipeline):
252
270
  context=positive_prompt_emb,
253
271
  )
254
272
  negative_noise_pred = self.predict_noise(
273
+ model=model,
255
274
  latents=latents,
256
275
  image_clip_feature=image_clip_feature,
257
276
  image_y=image_y,
@@ -270,6 +289,7 @@ class WanVideoPipeline(BasePipeline):
270
289
  if image_clip_feature is not None:
271
290
  image_clip_feature = torch.cat([image_clip_feature, image_clip_feature], dim=0)
272
291
  positive_noise_pred, negative_noise_pred = self.predict_noise(
292
+ model=model,
273
293
  latents=latents,
274
294
  image_clip_feature=image_clip_feature,
275
295
  image_y=image_y,
@@ -279,10 +299,10 @@ class WanVideoPipeline(BasePipeline):
279
299
  noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
280
300
  return noise_pred
281
301
 
282
- def predict_noise(self, latents, image_clip_feature, image_y, timestep, context):
302
+ def predict_noise(self, model, latents, image_clip_feature, image_y, timestep, context):
283
303
  latents = latents.to(dtype=self.config.model_dtype, device=self.device)
284
304
 
285
- noise_pred = self.dit(
305
+ noise_pred = model(
286
306
  x=latents,
287
307
  timestep=timestep,
288
308
  context=context,
@@ -298,17 +318,23 @@ class WanVideoPipeline(BasePipeline):
298
318
  denoising_strength,
299
319
  num_inference_steps,
300
320
  ):
301
- if input_video is not None:
321
+ height, width = latents.shape[-2:]
322
+ height, width = height * self.upsampling_factor, width * self.upsampling_factor
323
+ if input_video is not None: # video to video
302
324
  total_steps = num_inference_steps
303
325
  sigmas, timesteps = self.noise_scheduler.schedule(total_steps)
304
326
  t_start = max(total_steps - int(num_inference_steps * denoising_strength), 1)
305
327
  sigma_start, sigmas = sigmas[t_start - 1], sigmas[t_start - 1 :]
306
328
  timesteps = timesteps[t_start - 1 :]
307
329
 
330
+ self.load_models_to_device(["vae"])
308
331
  noise = latents
309
- input_video = self.preprocess_images(input_video)
310
- input_video = torch.stack(input_video, dim=2)
311
- latents = self.encode_video(input_video).to(dtype=latents.dtype, device=latents.device)
332
+ frames = [
333
+ self.preprocess_image(frame.resize((width, height), Image.Resampling.LANCZOS)) for frame in input_video
334
+ ]
335
+ video = torch.stack(frames, dim=2).squeeze(0)
336
+ video = video.to(dtype=self.config.vae_dtype, device=self.device)
337
+ latents = self.encode_video([video]).to(dtype=latents.dtype, device=latents.device)
312
338
  init_latents = latents.clone()
313
339
  latents = self.sampler.add_noise(latents, noise, sigma_start)
314
340
  else:
@@ -329,18 +355,29 @@ class WanVideoPipeline(BasePipeline):
329
355
  height=480,
330
356
  width=832,
331
357
  num_frames=81,
332
- cfg_scale=5.0,
333
- num_inference_steps=50,
358
+ cfg_scale=None,
359
+ num_inference_steps=None,
334
360
  progress_callback: Optional[Callable] = None, # def progress_callback(current, total, status)
335
361
  ):
336
362
  assert height % 16 == 0 and width % 16 == 0, "height and width must be divisible by 16"
337
363
  assert (num_frames - 1) % 4 == 0, "num_frames must be 4X+1"
364
+ cfg_scale = self.config.cfg_scale if cfg_scale is None else cfg_scale
365
+ num_inference_steps = self.config.num_inference_steps if num_inference_steps is None else num_inference_steps
338
366
 
339
367
  # Initialize noise
340
368
  if dist.is_initialized() and seed is None:
341
369
  raise ValueError("must provide a seed when parallelism is enabled")
342
370
  noise = self.generate_noise(
343
- (1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8), seed=seed, device="cpu", dtype=torch.float32
371
+ (
372
+ 1,
373
+ self.vae.z_dim,
374
+ (num_frames - 1) // 4 + 1,
375
+ height // self.upsampling_factor,
376
+ width // self.upsampling_factor,
377
+ ),
378
+ seed=seed,
379
+ device="cpu",
380
+ dtype=torch.float32,
344
381
  ).to(self.device)
345
382
  init_latents, latents, sigmas, timesteps = self.prepare_latents(
346
383
  noise,
@@ -348,33 +385,49 @@ class WanVideoPipeline(BasePipeline):
348
385
  denoising_strength,
349
386
  num_inference_steps,
350
387
  )
351
- self.sampler.initialize(init_latents=init_latents, timesteps=timesteps, sigmas=sigmas)
388
+ mask = torch.ones((1, 1, *latents.shape[2:]), dtype=latents.dtype, device=latents.device)
389
+
352
390
  # Encode prompts
353
391
  self.load_models_to_device(["text_encoder"])
354
392
  prompt_emb_posi = self.encode_prompt(prompt)
355
- prompt_emb_nega = None if cfg_scale <= 1.0 else self.encode_prompt(negative_prompt)
393
+ prompt_emb_nega = self.encode_prompt(negative_prompt)
356
394
 
357
395
  # Encode image
358
- if input_image is not None and self.image_encoder is not None:
359
- self.load_models_to_device(["image_encoder", "vae"])
360
- image_clip_feature, image_y = self.encode_image(input_image, num_frames, height, width)
361
- else:
362
- image_clip_feature, image_y = None, None
396
+ image_clip_feature = self.encode_clip_feature(input_image, height, width)
397
+ image_y = self.encode_vae_feature(input_image, num_frames, height, width)
398
+ image_latents = self.encode_image_latents(input_image, height, width)
399
+ if image_latents is not None:
400
+ latents[:, :, : image_latents.shape[2], :, :] = image_latents
401
+ init_latents = latents.clone()
402
+ mask[:, :, : image_latents.shape[2], :, :] = 0
403
+
404
+ # Initialize sampler
405
+ self.sampler.initialize(init_latents=init_latents, timesteps=timesteps, sigmas=sigmas, mask=mask)
363
406
 
364
407
  # Denoise
365
- self.load_models_to_device(["dit"])
366
408
  hide_progress = dist.is_initialized() and dist.get_rank() != 0
367
409
  for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
368
- timestep = timestep.unsqueeze(0).to(dtype=self.config.model_dtype, device=self.device)
410
+ if timestep.item() / 1000 >= self.config.boundary:
411
+ self.load_models_to_device(["dit"])
412
+ model = self.dit
413
+ cfg_scale_ = cfg_scale if isinstance(cfg_scale, float) else cfg_scale[1]
414
+ else:
415
+ self.load_models_to_device(["dit2"])
416
+ model = self.dit2
417
+ cfg_scale_ = cfg_scale if isinstance(cfg_scale, float) else cfg_scale[0]
418
+
419
+ timestep = timestep * mask[:, :, :, ::2, ::2].flatten() # seq_len
420
+ timestep = timestep.to(dtype=self.config.model_dtype, device=self.device)
369
421
  # Classifier-free guidance
370
422
  noise_pred = self.predict_noise_with_cfg(
423
+ model=model,
371
424
  latents=latents,
372
425
  timestep=timestep,
373
426
  positive_prompt_emb=prompt_emb_posi,
374
427
  negative_prompt_emb=prompt_emb_nega,
375
428
  image_clip_feature=image_clip_feature,
376
429
  image_y=image_y,
377
- cfg_scale=cfg_scale,
430
+ cfg_scale=cfg_scale_,
378
431
  batch_cfg=self.config.batch_cfg,
379
432
  )
380
433
  # Scheduler
@@ -385,7 +438,7 @@ class WanVideoPipeline(BasePipeline):
385
438
  # Decode
386
439
  self.load_models_to_device(["vae"])
387
440
  frames = self.decode_video(latents, progress_callback=progress_callback)
388
- frames = self.tensor2video(frames[0])
441
+ frames = self.vae_output_to_image(frames)
389
442
  return frames
390
443
 
391
444
  @classmethod
@@ -395,24 +448,73 @@ class WanVideoPipeline(BasePipeline):
395
448
  else:
396
449
  config = model_path_or_config
397
450
 
451
+ dit_state_dict, dit2_state_dict = None, None
452
+ if isinstance(config.model_path, list):
453
+ high_noise_model_ckpt = [path for path in config.model_path if "high_noise_model" in path]
454
+ low_noise_model_ckpt = [path for path in config.model_path if "low_noise_model" in path]
455
+ if high_noise_model_ckpt and low_noise_model_ckpt:
456
+ logger.info(f"loading high noise model state dict from {high_noise_model_ckpt} ...")
457
+ dit_state_dict = cls.load_model_checkpoint(
458
+ high_noise_model_ckpt, device="cpu", dtype=config.model_dtype
459
+ )
460
+ logger.info(f"loading low noise model state dict from {low_noise_model_ckpt} ...")
461
+ dit2_state_dict = cls.load_model_checkpoint(
462
+ low_noise_model_ckpt, device="cpu", dtype=config.model_dtype
463
+ )
464
+ if dit_state_dict is None:
465
+ logger.info(f"loading dit state dict from {config.model_path} ...")
466
+ dit_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
467
+
468
+ # determine wan dit type by model params
469
+ dit_type = None
470
+ if dit2_state_dict is not None and dit2_state_dict["patch_embedding.weight"].shape[1] == 36:
471
+ dit_type = "wan2.2-i2v-a14b"
472
+ elif dit2_state_dict is not None and dit2_state_dict["patch_embedding.weight"].shape[1] == 16:
473
+ dit_type = "wan2.2-t2v-a14b"
474
+ elif dit_state_dict["patch_embedding.weight"].shape[1] == 48:
475
+ dit_type = "wan2.2-ti2v-5b"
476
+ elif "img_emb.emb_pos" in dit_state_dict:
477
+ dit_type = "wan2.1-flf2v-14b"
478
+ elif "img_emb.proj.0.weight" in dit_state_dict:
479
+ dit_type = "wan2.1-i2v-14b"
480
+ elif "blocks.39.self_attn.norm_q.weight" in dit_state_dict:
481
+ dit_type = "wan2.1-t2v-14b"
482
+ else:
483
+ dit_type = "wan2.1-t2v-1.3b"
484
+
398
485
  if config.t5_path is None:
399
486
  config.t5_path = fetch_model("muse/wan2.1-umt5", path="umt5.safetensors")
400
487
  if config.vae_path is None:
401
- config.vae_path = fetch_model("muse/wan2.1-vae", path="vae.safetensors")
402
-
403
- logger.info(f"loading state dict from {config.model_path} ...")
404
- dit_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
488
+ config.vae_path = (
489
+ fetch_model("muse/wan2.2-vae", path="vae.safetensors")
490
+ if dit_type == "wan2.2-ti2v-5b"
491
+ else fetch_model("muse/wan2.1-vae", path="vae.safetensors")
492
+ )
405
493
 
406
- logger.info(f"loading state dict from {config.t5_path} ...")
494
+ logger.info(f"loading t5 state dict from {config.t5_path} ...")
407
495
  t5_state_dict = cls.load_model_checkpoint(config.t5_path, device="cpu", dtype=config.t5_dtype)
408
496
 
409
- logger.info(f"loading state dict from {config.vae_path} ...")
497
+ logger.info(f"loading vae state dict from {config.vae_path} ...")
410
498
  vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
411
499
 
500
+ # determine wan vae type by model params
501
+ vae_type = "wan2.1-vae"
502
+ if vae_state_dict["encoder.conv1.weight"].shape[1] == 12: # in_channels
503
+ vae_type = "wan2.2-vae"
504
+
505
+ # default params from model config
506
+ vae_config: dict = WanVideoVAE.get_model_config(vae_type)
507
+ model_config: dict = WanDiT.get_model_config(dit_type)
508
+ config.boundary = model_config.pop("boundary", -1.0)
509
+ config.shift = model_config.pop("shift", 5.0)
510
+ config.cfg_scale = model_config.pop("cfg_scale", 5.0)
511
+ config.num_inference_steps = model_config.pop("num_inference_steps", 50)
512
+ config.fps = model_config.pop("fps", 16)
513
+
412
514
  init_device = "cpu" if config.parallelism > 1 or config.offload_mode is not None else config.device
413
515
  tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=512, clean="whitespace")
414
516
  text_encoder = WanTextEncoder.from_state_dict(t5_state_dict, device=init_device, dtype=config.t5_dtype)
415
- vae = WanVideoVAE.from_state_dict(vae_state_dict, device=init_device, dtype=config.vae_dtype)
517
+ vae = WanVideoVAE.from_state_dict(vae_state_dict, config=vae_config, device=init_device, dtype=config.vae_dtype)
416
518
 
417
519
  image_encoder = None
418
520
  if config.image_encoder_path is not None:
@@ -428,20 +530,6 @@ class WanVideoPipeline(BasePipeline):
428
530
  dtype=config.image_encoder_dtype,
429
531
  )
430
532
 
431
- # determine wan video model type by dit params
432
- model_type = None
433
- if "img_emb.emb_pos" in dit_state_dict:
434
- model_type = "14b-flf2v"
435
- elif "img_emb.proj.0.weight" in dit_state_dict:
436
- model_type = "14b-i2v"
437
- elif "blocks.39.self_attn.norm_q.weight" in dit_state_dict:
438
- model_type = "14b-t2v"
439
- else:
440
- model_type = "1.3b-t2v"
441
-
442
- # shift for different model_type
443
- config.shift = SHIFT_FACTORS[model_type] if config.shift is None else config.shift
444
-
445
533
  with LoRAContext():
446
534
  attn_kwargs = {
447
535
  "attn_impl": config.dit_attn_impl,
@@ -452,7 +540,7 @@ class WanVideoPipeline(BasePipeline):
452
540
  }
453
541
  dit = WanDiT.from_state_dict(
454
542
  dit_state_dict,
455
- model_type=model_type,
543
+ config=model_config,
456
544
  device=init_device,
457
545
  dtype=config.model_dtype,
458
546
  attn_kwargs=attn_kwargs,
@@ -460,11 +548,24 @@ class WanVideoPipeline(BasePipeline):
460
548
  if config.use_fp8_linear:
461
549
  enable_fp8_linear(dit)
462
550
 
551
+ dit2 = None
552
+ if dit2_state_dict is not None:
553
+ dit2 = WanDiT.from_state_dict(
554
+ dit2_state_dict,
555
+ config=model_config,
556
+ device=init_device,
557
+ dtype=config.model_dtype,
558
+ attn_kwargs=attn_kwargs,
559
+ )
560
+ if config.use_fp8_linear:
561
+ enable_fp8_linear(dit2)
562
+
463
563
  pipe = cls(
464
564
  config=config,
465
565
  tokenizer=tokenizer,
466
566
  text_encoder=text_encoder,
467
567
  dit=dit,
568
+ dit2=dit2,
468
569
  vae=vae,
469
570
  image_encoder=image_encoder,
470
571
  )
@@ -23,10 +23,16 @@ SD3_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd3", "sd3_tex
23
23
  SDXL_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_text_encoder.json")
24
24
  SDXL_UNET_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_unet.json")
25
25
 
26
- WAN_DIT_1_3B_T2V_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "1.3b-t2v.json")
27
- WAN_DIT_14B_I2V_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "14b-i2v.json")
28
- WAN_DIT_14B_T2V_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "14b-t2v.json")
29
- WAN_DIT_14B_FLF2V_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "14b-flf2v.json")
26
+ WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-1.3b.json")
27
+ WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-14b.json")
28
+ WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-i2v-14b.json")
29
+ WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-flf2v-14b.json")
30
+ WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-ti2v-5b.json")
31
+ WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-t2v-a14b.json")
32
+ WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-i2v-a14b.json")
33
+
34
+ WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1-vae.json")
35
+ WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2-vae.json")
30
36
 
31
37
  # data size
32
38
  KB = 1024
@@ -336,6 +336,9 @@ class ParallelWrapper:
336
336
  except RuntimeError as e:
337
337
  raise RuntimeError("Failed to set start method to spawn:", e)
338
338
  super().__init__()
339
+ self.config = module.config if isinstance(module, BasePipeline) else None
340
+ self._module_name = module.__class__.__name__
341
+
339
342
  self.world_size = cfg_degree * sp_ulysses_degree * sp_ring_degree * tp_degree
340
343
  self.queue_in = mp.Queue()
341
344
  self.queue_out = mp.Queue()
@@ -357,7 +360,6 @@ class ParallelWrapper:
357
360
  nprocs=self.world_size,
358
361
  join=False,
359
362
  )
360
- self._module_name = module.__class__.__name__
361
363
 
362
364
  def __call__(self, *args, **kwargs):
363
365
  data = ["__call__", args, kwargs]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.3.6.dev13
3
+ Version: 0.3.6.dev14
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -15,7 +15,7 @@ diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py,sha256=QDz
15
15
  diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py,sha256=ZQ5OLY6_CMmV0V2MtUzHxcXyVpanhMopWYiRr2CtFTk,683
16
16
  diffsynth_engine/algorithm/sampler/__init__.py,sha256=Ow07B9JeQbgCjDtaxYPeU_p2k76CUOuHGDGvoAyD1SU,725
17
17
  diffsynth_engine/algorithm/sampler/flow_match/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
- diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py,sha256=wJVanlk6R075fBNGHwA3BJENxRyPgOslIjo0VGRXgKQ,746
18
+ diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py,sha256=7wnI0MjgaTZnkK-JWAHIStloCnQNgpZx4JVKjw2SAEE,733
19
19
  diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
20
  diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py,sha256=Y2yTAp_aMWb7z8V7GO48jwPv1LhEkxVJTcsljq0qHqg,2106
21
21
  diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py,sha256=H5BkzNAUpp_RhTmHRyYWekCYjChX7uuuxFSgE5ZBAlc,1288
@@ -36,10 +36,15 @@ diffsynth_engine/conf/models/sd3/sd3_dit.json,sha256=RyJeCKjd4UPRf2Qbicd8Oxlioxg
36
36
  diffsynth_engine/conf/models/sd3/sd3_text_encoder.json,sha256=1yXwzKbbIIVg1QPhQJxjdwvbFkA1mJ6NR6dw2vrN-1A,91415
37
37
  diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json,sha256=cBN3mIm4BjJYbSpL2gz4yeb1aP0BvGt9na4hmuafyJo,35642
38
38
  diffsynth_engine/conf/models/sdxl/sdxl_unet.json,sha256=9f9ca1qYQALaDkA5KTCfVP9mKFvhM2xFP5e042Ryppw,129779
39
- diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json,sha256=mtBk_lj4R18wadhaIu-EDSErhQik3GK2xHSUkGWW6BM,239
40
- diffsynth_engine/conf/models/wan/dit/14b-flf2v.json,sha256=xsONv_2NgDmcBKsNtjShN2Llp2gniPdKHAfgEOdTtHQ,264
41
- diffsynth_engine/conf/models/wan/dit/14b-i2v.json,sha256=5xnvVzevKep0xQcbNkuIlskF1jS6co9y8WsZV2BqV9Q,239
42
- diffsynth_engine/conf/models/wan/dit/14b-t2v.json,sha256=NkggFTpaKb2pOXRLpEi3xv3uXrJdsmY-mve_UjAoVR4,240
39
+ diffsynth_engine/conf/models/wan/dit/wan2.1-flf2v-14b.json,sha256=s7yoVErSiuSlGwwqfrvhvmzz6MD4oAqBKg7iZfL1vX8,313
40
+ diffsynth_engine/conf/models/wan/dit/wan2.1-i2v-14b.json,sha256=BkDV80TkA-_vTRR_1AWpGIzwlgtuKbh-gezW2Q20dlQ,269
41
+ diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-1.3b.json,sha256=M_h55-mMhpgXUuY85sBK6-_f4fg3bfCa6T7n1CyMP3s,209
42
+ diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-14b.json,sha256=7i2Hq8BRH4kDVYBKcIBt8m3vCl_HGZZPFY5fmFw4xgs,210
43
+ diffsynth_engine/conf/models/wan/dit/wan2.2-i2v-a14b.json,sha256=7OmPEfreIu8Ex6NDr1IW69zmKRp21hZkmg_9yg6sUg8,322
44
+ diffsynth_engine/conf/models/wan/dit/wan2.2-t2v-a14b.json,sha256=MqxjGwq8VqD-1RwbPocbkKx0JzsMgwn18hfVK7M0d4k,312
45
+ diffsynth_engine/conf/models/wan/dit/wan2.2-ti2v-5b.json,sha256=tO7nymyqQgBIgxlswITnIc_MsRr1RRPhZbbhJ-1gHow,257
46
+ diffsynth_engine/conf/models/wan/vae/wan2.1-vae.json,sha256=eVLTSRqbXm3JD8QDkLbM6vFfCdynlS-8QxqCfi4BzrI,815
47
+ diffsynth_engine/conf/models/wan/vae/wan2.2-vae.json,sha256=pdnYEEZ_GcZHM_iH1y5ASdf_qZUGCOuDEaFmjdg9RKY,1860
43
48
  diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt,sha256=n9aR98gDkhDg_O0VhlRmxlgg0JtjmIsBdL_iXeKZBRo,524619
44
49
  diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json,sha256=LNs7gzGmDJL8HlWhPp_WH9IpPFpRJ1_czNYreABSUw4,588
45
50
  diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json,sha256=a9zunMzioWyitMDF7QC0LFDqIl9EcqjEweljopAsKIE,705
@@ -62,7 +67,7 @@ diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json,sha256=bhl7TT29cdoU
62
67
  diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json,sha256=7Zo6iw-qcacKMoR-BDX-A25uES1N9O23u0ipIeNE3AU,61728
63
68
  diffsynth_engine/configs/__init__.py,sha256=qvfbnHf3wK9THPU_mFr1Qx_lU80BaUp5HpxUmjoNy60,502
64
69
  diffsynth_engine/configs/controlnet.py,sha256=EpUkCdRNk2G5uo56syaOzPFdR9g0sDHRXckagmMsgaQ,948
65
- diffsynth_engine/configs/pipeline.py,sha256=NPQlNz-AOpi8qFzRob0RNnOqSc8C-vCdHbstLyUugeo,7703
70
+ diffsynth_engine/configs/pipeline.py,sha256=ltTn4tRETeS-v7IMm3fkokza2PKluVb2DWGk8mPOghU,8079
66
71
  diffsynth_engine/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
67
72
  diffsynth_engine/models/__init__.py,sha256=8Ze7cSE8InetgXWTNb0neVA2Q44K7WlE-h7O-02m2sY,119
68
73
  diffsynth_engine/models/base.py,sha256=sbyyGP-ENnqicr6cxjEmXRf6dWrmKjCu6k5yamuJ518,2665
@@ -103,17 +108,17 @@ diffsynth_engine/models/text_encoder/t5.py,sha256=iSYyYQF4DUU0zpN65V_slWoftBTDVw
103
108
  diffsynth_engine/models/vae/__init__.py,sha256=TFSIXZ-UyRaZbEr5KUXm1d4koS5gbgsCi7Soh6jDV0Y,140
104
109
  diffsynth_engine/models/vae/vae.py,sha256=FWMVqahY1BdnIkzLi8ykCp_VWHs05l0JF21wk7763LI,15844
105
110
  diffsynth_engine/models/wan/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
106
- diffsynth_engine/models/wan/wan_dit.py,sha256=gUd9KeMl7y_VPLntGoGtT2Io94opPiKlrrfo8E8O-8o,18869
111
+ diffsynth_engine/models/wan/wan_dit.py,sha256=05UG5B3wMu565HGCkfTMHjUHxT18xZ_lz0rvNqVoMqM,19753
107
112
  diffsynth_engine/models/wan/wan_image_encoder.py,sha256=LYwcfCcQmXf9FP08DGaU2bfaPgFfdpJ23OpJP8UCggo,14397
108
113
  diffsynth_engine/models/wan/wan_text_encoder.py,sha256=bkphxtqNNwXcEA_OaUrwV9CvICV-s16awu5Z9gjjzsM,10912
109
- diffsynth_engine/models/wan/wan_vae.py,sha256=RxyuHExQmRjGBAqhZdIbtwZFdCibTzh__U4-Sa00zdI,29004
114
+ diffsynth_engine/models/wan/wan_vae.py,sha256=bYXW-7FdLAi7391y9nQbYpKPYFDYyTxWbY_0TBrn2Yw,38444
110
115
  diffsynth_engine/pipelines/__init__.py,sha256=kTvANqHcMPrHqiJVg-XohfqRdW6Cj4aElfItTb1B7Vs,380
111
- diffsynth_engine/pipelines/base.py,sha256=yVp4hSPCqk98azzy3ykKBfPAufvq_ncTFOURN95z7d0,12178
116
+ diffsynth_engine/pipelines/base.py,sha256=o31tD_iFobNbEPsl_d8ih9-GL023-qqb55r06i0SvAw,12050
112
117
  diffsynth_engine/pipelines/flux_image.py,sha256=MtQqTnCqQjIFovhA3lzBXpnkS4DkZH2PtFUwNZdl42M,48839
113
118
  diffsynth_engine/pipelines/sd_image.py,sha256=5dGIa6crtklO7xPd1eeBVkqj54Pe89Uo3bMyXVEaXxM,17822
114
119
  diffsynth_engine/pipelines/sdxl_image.py,sha256=Ns4bCSO3BtCXdjGJEQ0s5oY0S3jrp5yE5lhfon-iNiw,21575
115
120
  diffsynth_engine/pipelines/utils.py,sha256=VfSTwRejSVSKXIa7w0VhObmvaBFRvDP-uiYsHHkPAgs,165
116
- diffsynth_engine/pipelines/wan_video.py,sha256=vi_xW-jU4PeMtZzjkfQbnj8eOymJrTZMrOQau6tx6ks,20187
121
+ diffsynth_engine/pipelines/wan_video.py,sha256=wbCHPDgs4BmyX1DsawaXqxeCoVAcISNcXNnFr2qcTx0,25424
117
122
  diffsynth_engine/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
118
123
  diffsynth_engine/processor/canny_processor.py,sha256=hV30NlblTkEFUAmF_O-LJrNlGVM2SFrqq6okfF8VpOo,602
119
124
  diffsynth_engine/processor/depth_processor.py,sha256=dQvs3JsnyMbz4dyI9QoR8oO-mMFBFAgNvgqeCoaU5jk,1532
@@ -128,7 +133,7 @@ diffsynth_engine/tools/flux_outpainting_tool.py,sha256=sxGRAiht27he9CT_dL9KkXVvM
128
133
  diffsynth_engine/tools/flux_reference_tool.py,sha256=BJlXQxH8j3AhEhlymIlE6OnIH2gU_l_qv5k10JDZKng,3705
129
134
  diffsynth_engine/tools/flux_replace_tool.py,sha256=M_q8KnsBEwNi4w8NOK-F2Bmj7cKUNcA9QMzwrp3zm6E,3336
130
135
  diffsynth_engine/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
131
- diffsynth_engine/utils/constants.py,sha256=L7sIxGNMfCvcZG66ul7GIT6fDctkcwhePAjMjG6WXx8,1969
136
+ diffsynth_engine/utils/constants.py,sha256=9N0BuLmDeHgiKAlu1vaCTb9-tPClbCA8nTu916_UumM,2510
132
137
  diffsynth_engine/utils/download.py,sha256=NCgfL9tUca-sOhT41k6w4o__Ktbw-1aDwFTR4JDkT28,5639
133
138
  diffsynth_engine/utils/env.py,sha256=43x-kBjt5zI2cwZ9G4BOeTbedi2k6TuBzHGOBeFbFvU,280
134
139
  diffsynth_engine/utils/flag.py,sha256=6zQLnoEaU69pBEyhavCgydQfP0khw5ppCU7sue4yRqg,1370
@@ -140,12 +145,12 @@ diffsynth_engine/utils/lock.py,sha256=1Ipgst9eEFfFdViAvD5bxdB6HnHHBcqWYOb__fGaPU
140
145
  diffsynth_engine/utils/logging.py,sha256=XB0xTT8PBN6btkOjFtOvjlrOCRVgDGT8PFAp1vmse28,467
141
146
  diffsynth_engine/utils/offload.py,sha256=jUR4u7J60o4KZIRxHhMCwaeDkiXJvBa0KJkYKKT6mrg,1587
142
147
  diffsynth_engine/utils/onnx.py,sha256=jeWUudJHnESjuiEAHyUZYUZz7dCj34O9aGjHCe8yjWo,1149
143
- diffsynth_engine/utils/parallel.py,sha256=gbIeilfOYsqeDcgkaP68TfLjIXxvD0KfLiAsR_8gJco,16917
148
+ diffsynth_engine/utils/parallel.py,sha256=eXFglYH2w478oYusktpllm0v4IC8CABGmy0HsE-zE_8,17000
144
149
  diffsynth_engine/utils/platform.py,sha256=2lXdw6YkqcRONCeT98n4cyg1Ii8Ybbyj2Ns72Se9tlk,496
145
150
  diffsynth_engine/utils/prompt.py,sha256=YItMchoVzsG6y-LB4vzzDUWrkhKRVlt1HfVhxZjSxMQ,280
146
151
  diffsynth_engine/utils/video.py,sha256=Ne0rd2lb59UT1q5EotpjlY7OT8F9oTCFDyo1ST77uoQ,1004
147
- diffsynth_engine-0.3.6.dev13.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
148
- diffsynth_engine-0.3.6.dev13.dist-info/METADATA,sha256=2jH1jlJdbUga4JOoDHfRyKEn6E4xQ1w9wRhLKVYaqRk,1069
149
- diffsynth_engine-0.3.6.dev13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
150
- diffsynth_engine-0.3.6.dev13.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
151
- diffsynth_engine-0.3.6.dev13.dist-info/RECORD,,
152
+ diffsynth_engine-0.3.6.dev14.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
153
+ diffsynth_engine-0.3.6.dev14.dist-info/METADATA,sha256=qRZwaOSJBZh1MNJLITijv3WWBJEWHoN0Dyx-CIPdd2w,1069
154
+ diffsynth_engine-0.3.6.dev14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
155
+ diffsynth_engine-0.3.6.dev14.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
156
+ diffsynth_engine-0.3.6.dev14.dist-info/RECORD,,