diffsynth-engine 0.6.1.dev38__py3-none-any.whl → 0.6.1.dev39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -27,6 +27,7 @@ from .pipelines import (
27
27
  SDXLImagePipeline,
28
28
  FluxImagePipeline,
29
29
  WanVideoPipeline,
30
+ WanDMDPipeline,
30
31
  QwenImagePipeline,
31
32
  Hunyuan3DShapePipeline,
32
33
  )
@@ -81,6 +82,7 @@ __all__ = [
81
82
  "FluxIPAdapter",
82
83
  "FluxRedux",
83
84
  "WanVideoPipeline",
85
+ "WanDMDPipeline",
84
86
  "QwenImagePipeline",
85
87
  "Hunyuan3DShapePipeline",
86
88
  "FluxInpaintingTool",
@@ -4,6 +4,7 @@ from .sdxl_image import SDXLImagePipeline
4
4
  from .sd_image import SDImagePipeline
5
5
  from .wan_video import WanVideoPipeline
6
6
  from .wan_s2v import WanSpeech2VideoPipeline
7
+ from .wan_dmd import WanDMDPipeline
7
8
  from .qwen_image import QwenImagePipeline
8
9
  from .hunyuan3d_shape import Hunyuan3DShapePipeline
9
10
  from .z_image import ZImagePipeline
@@ -16,6 +17,7 @@ __all__ = [
16
17
  "SDImagePipeline",
17
18
  "WanVideoPipeline",
18
19
  "WanSpeech2VideoPipeline",
20
+ "WanDMDPipeline",
19
21
  "QwenImagePipeline",
20
22
  "Hunyuan3DShapePipeline",
21
23
  "ZImagePipeline",
@@ -145,7 +145,7 @@ class BasePipeline:
145
145
  self.load_loras([(path, scale)], fused, save_original_weight)
146
146
 
147
147
  def apply_scheduler_config(self, scheduler_config: Dict):
148
- pass
148
+ self.noise_scheduler.update_config(scheduler_config)
149
149
 
150
150
  def unload_loras(self):
151
151
  raise NotImplementedError()
@@ -393,9 +393,6 @@ class QwenImagePipeline(BasePipeline):
393
393
  self.dit.unload_loras()
394
394
  self.noise_scheduler.restore_config()
395
395
 
396
- def apply_scheduler_config(self, scheduler_config: Dict):
397
- self.noise_scheduler.update_config(scheduler_config)
398
-
399
396
  def prepare_latents(
400
397
  self,
401
398
  latents: torch.Tensor,
@@ -0,0 +1,111 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ from typing import Callable, List, Optional
4
+ from tqdm import tqdm
5
+ from PIL import Image
6
+
7
+ from diffsynth_engine.pipelines.wan_video import WanVideoPipeline
8
+
9
+
class WanDMDPipeline(WanVideoPipeline):
    """Wan video pipeline that denoises over an explicit, short list of timesteps.

    Unlike a step-count driven pipeline, callers pass ``denoising_step_list``;
    each entry selects one point of a 1000-step noise schedule. Every step runs
    with ``cfg_scale=1.0`` and no negative prompt embedding.
    NOTE(review): the name suggests DMD-distilled checkpoints -- confirm.
    """

    def prepare_latents(
        self,
        latents,
        denoising_step_list,
    ):
        """Select the sigmas and timesteps for the requested denoising steps.

        Args:
            latents: Initial noise latents. Returned as-is as the working latents.
            denoising_step_list: Steps on the 1000-step schedule scale, e.g.
                ``[1000, 750, 500, 250]``. A step ``t`` picks schedule entry
                ``1000 - t``.

        Returns:
            ``(init_latents, latents, sigmas, timesteps)``. ``init_latents`` is a
            clone of ``latents``; ``sigmas`` holds one entry per step plus the
            last sigma of the schedule; ``timesteps`` holds one entry per step.

        Raises:
            ValueError: If ``denoising_step_list`` is empty or contains a step
                greater than 1000.
        """
        schedule_steps = 1000
        indices = [schedule_steps - t for t in denoising_step_list]
        if not indices:
            raise ValueError("denoising_step_list must contain at least one step")
        # A step above the schedule length would become a negative index and
        # silently wrap around to the end of the schedule.
        if any(index < 0 for index in indices):
            raise ValueError(f"denoising steps must not exceed {schedule_steps}, got {list(denoising_step_list)}")

        sigmas, timesteps = self.noise_scheduler.schedule(num_inference_steps=schedule_steps)
        sigmas = sigmas[indices + [-1]]  # trailing entry is the schedule's final sigma
        timesteps = timesteps[indices]
        init_latents = latents.clone()

        return init_latents, latents, sigmas, timesteps

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        input_image: Optional[Image.Image] = None,
        seed=None,
        height=480,
        width=832,
        num_frames=81,
        denoising_step_list: Optional[List[int]] = None,
        progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
    ):
        """Generate video frames from a prompt and an optional conditioning image.

        Args:
            prompt: Text prompt passed to ``encode_prompt``.
            input_image: Optional conditioning image. When it yields image
                latents, the leading latent frames are pinned to them.
            seed: Noise seed. Required when ``torch.distributed`` is initialized.
            height: Output height; must be divisible by 16 or 32 depending on the VAE.
            width: Output width; same divisibility rule as ``height``.
            num_frames: Number of frames; must be of the form ``4 * k + 1``.
            denoising_step_list: Steps to denoise at; defaults to
                ``[1000, 750, 500, 250]``.
            progress_callback: Optional ``callback(current, total, status)``.

        Returns:
            The result of ``vae_output_to_image`` applied to the decoded video.

        Raises:
            AssertionError: If ``height``/``width``/``num_frames`` violate the
                constraints above (kept as assertions so the raised type is unchanged).
            ValueError: If parallelism is enabled without a seed, or the step
                list is rejected by ``prepare_latents``.
        """
        denoising_step_list = [1000, 750, 500, 250] if denoising_step_list is None else denoising_step_list
        divisor = 32 if self.vae.z_dim == 48 else 16  # 32 for wan2.2 vae, 16 for wan2.1 vae
        assert height % divisor == 0 and width % divisor == 0, f"height and width must be divisible by {divisor}"
        assert (num_frames - 1) % 4 == 0, "num_frames must be 4X+1"

        # Initialize noise
        if dist.is_initialized() and seed is None:
            raise ValueError("must provide a seed when parallelism is enabled")
        noise = self.generate_noise(
            (
                1,
                self.vae.z_dim,
                (num_frames - 1) // 4 + 1,
                height // self.upsampling_factor,
                width // self.upsampling_factor,
            ),
            seed=seed,
            device="cpu",
            dtype=torch.float32,
        ).to(self.device)
        init_latents, latents, sigmas, timesteps = self.prepare_latents(noise, denoising_step_list)
        # 1 marks latent positions that are denoised, 0 marks positions held fixed.
        mask = torch.ones((1, 1, *latents.shape[2:]), dtype=latents.dtype, device=latents.device)

        # Encode prompts
        self.load_models_to_device(["text_encoder"])
        prompt_emb_posi = self.encode_prompt(prompt)
        prompt_emb_nega = None  # no negative prompt: guidance runs with cfg_scale=1.0

        # Encode image
        image_clip_feature = self.encode_clip_feature(input_image, height, width)
        image_y = self.encode_vae_feature(input_image, num_frames, height, width)
        image_latents = self.encode_image_latents(input_image, height, width)
        if image_latents is not None:
            # Pin the leading latent frames to the encoded image and exclude
            # them from denoising.
            latents[:, :, : image_latents.shape[2], :, :] = image_latents
            init_latents = latents.clone()
            mask[:, :, : image_latents.shape[2], :, :] = 0

        # Initialize sampler
        self.sampler.initialize(sigmas=sigmas)

        # Denoise
        hide_progress = dist.is_initialized() and dist.get_rank() != 0
        for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
            # Steps at or above the configured boundary use ``dit``; the rest use ``dit2``.
            if timestep.item() / 1000 >= self.config.boundary:
                self.load_models_to_device(["dit"])
                model = self.dit
            else:
                self.load_models_to_device(["dit2"])
                model = self.dit2

            # Per-position timestep: masked-out (pinned) positions get timestep 0.
            timestep = timestep * mask[:, :, :, ::2, ::2].flatten()  # seq_len
            timestep = timestep.to(dtype=self.dtype, device=self.device)
            # Classifier-free guidance
            noise_pred = self.predict_noise_with_cfg(
                model=model,
                latents=latents,
                timestep=timestep,
                positive_prompt_emb=prompt_emb_posi,
                negative_prompt_emb=prompt_emb_nega,
                image_clip_feature=image_clip_feature,
                image_y=image_y,
                cfg_scale=1.0,
                batch_cfg=self.config.batch_cfg,
            )
            # Scheduler
            latents = self.sampler.step(latents, noise_pred, i)
            # Restore the pinned positions after every sampler step.
            latents = latents * mask + init_latents * (1 - mask)
            if progress_callback is not None:
                progress_callback(i + 1, len(timesteps), "DENOISING")

        # Decode
        self.load_models_to_device(["vae"])
        frames = self.decode_video(latents, progress_callback=progress_callback)
        frames = self.vae_output_to_image(frames)
        return frames
@@ -43,6 +43,24 @@ class WanLoRAConverter(LoRAStateDictConverter):
43
43
  dit_dict[key] = lora_args
44
44
  return {"dit": dit_dict}
45
45
 
46
+ def _from_diffusers(self, state_dict):
47
+ dit_dict = {}
48
+ for key, param in state_dict.items():
49
+ if ".lora_down.weight" not in key:
50
+ continue
51
+
52
+ lora_args = {}
53
+ lora_args["up"] = state_dict[key.replace(".lora_down.weight", ".lora_up.weight")]
54
+ lora_args["down"] = param
55
+ lora_args["rank"] = lora_args["up"].shape[1]
56
+ if key.replace(".lora_down.weight", ".alpha") in state_dict:
57
+ lora_args["alpha"] = state_dict[key.replace(".lora_down.weight", ".alpha")]
58
+ else:
59
+ lora_args["alpha"] = lora_args["rank"]
60
+ key = key.replace("diffusion_model.", "").replace(".lora_down.weight", "")
61
+ dit_dict[key] = lora_args
62
+ return {"dit": dit_dict}
63
+
46
64
  def _from_civitai(self, state_dict):
47
65
  dit_dict = {}
48
66
  for key, param in state_dict.items():
@@ -86,6 +104,9 @@ class WanLoRAConverter(LoRAStateDictConverter):
86
104
  if "lora_unet_blocks_0_cross_attn_k.lora_down.weight" in state_dict:
87
105
  state_dict = self._from_fun(state_dict)
88
106
  logger.info("use fun format state dict")
107
+ elif "diffusion_model.blocks.0.cross_attn.k.lora_down.weight" in state_dict:
108
+ state_dict = self._from_diffusers(state_dict)
109
+ logger.info("use diffusers format state dict")
89
110
  elif "diffusion_model.blocks.0.cross_attn.k.lora_A.weight" in state_dict:
90
111
  state_dict = self._from_civitai(state_dict)
91
112
  logger.info("use civitai format state dict")
@@ -480,8 +501,8 @@ class WanVideoPipeline(BasePipeline):
480
501
 
481
502
  dit_state_dict, dit2_state_dict = None, None
482
503
  if isinstance(config.model_path, list):
483
- high_noise_model_ckpt = [path for path in config.model_path if "high_noise_model" in path]
484
- low_noise_model_ckpt = [path for path in config.model_path if "low_noise_model" in path]
504
+ high_noise_model_ckpt = [path for path in config.model_path if "high_noise" in path]
505
+ low_noise_model_ckpt = [path for path in config.model_path if "low_noise" in path]
485
506
  if high_noise_model_ckpt and low_noise_model_ckpt:
486
507
  logger.info(f"loading high noise model state dict from {high_noise_model_ckpt} ...")
487
508
  dit_state_dict = cls.load_model_checkpoint(
@@ -681,8 +702,9 @@ class WanVideoPipeline(BasePipeline):
681
702
  config.attn_params = VideoSparseAttentionParams(sparsity=0.9)
682
703
 
683
704
  def update_weights(self, state_dicts: WanStateDicts) -> None:
684
- is_dual_model_state_dict = (isinstance(state_dicts.model, dict) and
685
- ("high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model))
705
+ is_dual_model_state_dict = isinstance(state_dicts.model, dict) and (
706
+ "high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model
707
+ )
686
708
  is_dual_model_pipeline = self.dit2 is not None
687
709
 
688
710
  if is_dual_model_state_dict != is_dual_model_pipeline:
@@ -694,15 +716,21 @@ class WanVideoPipeline(BasePipeline):
694
716
 
695
717
  if is_dual_model_state_dict:
696
718
  if "high_noise_model" in state_dicts.model:
697
- self.update_component(self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype)
719
+ self.update_component(
720
+ self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype
721
+ )
698
722
  if "low_noise_model" in state_dicts.model:
699
- self.update_component(self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype)
723
+ self.update_component(
724
+ self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype
725
+ )
700
726
  else:
701
727
  self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
702
728
 
703
729
  self.update_component(self.text_encoder, state_dicts.t5, self.config.device, self.config.t5_dtype)
704
730
  self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
705
- self.update_component(self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype)
731
+ self.update_component(
732
+ self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype
733
+ )
706
734
 
707
735
  def compile(self):
708
736
  self.dit.compile_repeated_blocks()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.6.1.dev38
3
+ Version: 0.6.1.dev39
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -1,4 +1,4 @@
1
- diffsynth_engine/__init__.py,sha256=hN0jYaikjhpqHB4Mg-e53h-7ck1DsiY4FBti8K9lN2k,2390
1
+ diffsynth_engine/__init__.py,sha256=um2Vh4BgmAAG66LafdcTXPiJ6dFtBU85xwPSKZOswFE,2432
2
2
  diffsynth_engine/algorithm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  diffsynth_engine/algorithm/noise_scheduler/__init__.py,sha256=YvcwE2tCNua-OAX9GEPm0EXsINNWH4XvJMNZb-uaZMM,745
4
4
  diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py,sha256=3ve4bYxGyfuERynvoNYdFYSk0agdBgXKCeIOS6O6wgI,819
@@ -150,16 +150,17 @@ diffsynth_engine/models/wan/wan_vae.py,sha256=dC7MoUFeXRL7SIY0LG1OOUiZW-pp9IbXCg
150
150
  diffsynth_engine/models/z_image/__init__.py,sha256=d1ztBNgM8GR2_uGwlxOE1Jf5URTq1g-WnmJH7nrMoaY,160
151
151
  diffsynth_engine/models/z_image/qwen3.py,sha256=PmT6m46Fc7KZXNzG7ig23Mzj6QfHnMmrpX_MM0UuuYg,4580
152
152
  diffsynth_engine/models/z_image/z_image_dit.py,sha256=kGtYzmfzk_FDe7KWfXpJagN7k7ROXl5J01IhRRs-Bsk,23806
153
- diffsynth_engine/pipelines/__init__.py,sha256=xQUtz2cVmcEInazvT1dqv2HdPiJKmywWTIPfbK5dZXI,662
154
- diffsynth_engine/pipelines/base.py,sha256=ShRiX5MY6bUkRKfuGrA1aalAqeHyeZxhzT87Mwc30b4,17231
153
+ diffsynth_engine/pipelines/__init__.py,sha256=44odpJm_Jnkzdbl1GDq9XVu4LN0_SICsK5ubjYKWeg4,720
154
+ diffsynth_engine/pipelines/base.py,sha256=h6xOqT1LMFGrJYoTD68_VoHcfRX04je8KUE_y3BUZfM,17279
155
155
  diffsynth_engine/pipelines/flux_image.py,sha256=L0ggxpthLD8a5-zdPHu9z668uWBei9YzPb4PFVypDNU,50707
156
156
  diffsynth_engine/pipelines/hunyuan3d_shape.py,sha256=TNV0Wr09Dj2bzzlpua9WioCClOj3YiLfE6utI9aWL8A,8164
157
- diffsynth_engine/pipelines/qwen_image.py,sha256=9n0fZCYw5E1iloXqd7vOU0XfHVPxQp_pm-v4D3Oloos,35751
157
+ diffsynth_engine/pipelines/qwen_image.py,sha256=Xc3H5LiQj2MUdi2KgFD2G2VqDwUa2ehqj4H35sr8iro,35627
158
158
  diffsynth_engine/pipelines/sd_image.py,sha256=nr-Nhsnomq8CsUqhTM3i2l2zG01YjwXdfRXgr_bC3F0,17891
159
159
  diffsynth_engine/pipelines/sdxl_image.py,sha256=v7ZACGPb6EcBunL6e5E9jynSQjE7GQx8etEV-ZLP91g,21704
160
160
  diffsynth_engine/pipelines/utils.py,sha256=HZbJHErNJS1DhlwJKvZ9dY7Kh8Zdlsw3zE2e88TYGRY,2277
161
+ diffsynth_engine/pipelines/wan_dmd.py,sha256=T_i4xp_tASFSaKZxg50FEAk5TOn89JSNv-4y5Os6Q6E,4508
161
162
  diffsynth_engine/pipelines/wan_s2v.py,sha256=QHlCLMqlmnp55iYm2mzg4qCq4jceRAP3Zt5Mubz3mAM,29384
162
- diffsynth_engine/pipelines/wan_video.py,sha256=9xjSvQ4mlVEDdaL6QuUURj4iyxhJ2xABBphQjkfzK8s,31323
163
+ diffsynth_engine/pipelines/wan_video.py,sha256=9nUV6h2zBbGu3gvVSM_oqdoruCdBWoa7t6vrJYJt8QY,32391
163
164
  diffsynth_engine/pipelines/z_image.py,sha256=VvqjxsKRsmP2tfWg9nDlcQu5oEzIRFa2wtuArzjQAlk,16151
164
165
  diffsynth_engine/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
165
166
  diffsynth_engine/processor/canny_processor.py,sha256=hV30NlblTkEFUAmF_O-LJrNlGVM2SFrqq6okfF8VpOo,602
@@ -199,8 +200,8 @@ diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CD
199
200
  diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
200
201
  diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
201
202
  diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
202
- diffsynth_engine-0.6.1.dev38.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
203
- diffsynth_engine-0.6.1.dev38.dist-info/METADATA,sha256=0fI0prUJox3z_sDzvhl-wh6wlCCYCA7N-naxpobysL0,1164
204
- diffsynth_engine-0.6.1.dev38.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
205
- diffsynth_engine-0.6.1.dev38.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
206
- diffsynth_engine-0.6.1.dev38.dist-info/RECORD,,
203
+ diffsynth_engine-0.6.1.dev39.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
204
+ diffsynth_engine-0.6.1.dev39.dist-info/METADATA,sha256=f_qU_vp4RcHSOgW3Agm428engf8v7TKRCt8DuxAOEi8,1164
205
+ diffsynth_engine-0.6.1.dev39.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
206
+ diffsynth_engine-0.6.1.dev39.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
207
+ diffsynth_engine-0.6.1.dev39.dist-info/RECORD,,