diffsynth-engine 0.5.1.dev4__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. diffsynth_engine/__init__.py +12 -0
  2. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
  3. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
  4. diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
  5. diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
  6. diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  7. diffsynth_engine/configs/__init__.py +16 -1
  8. diffsynth_engine/configs/controlnet.py +13 -0
  9. diffsynth_engine/configs/pipeline.py +37 -11
  10. diffsynth_engine/models/base.py +1 -1
  11. diffsynth_engine/models/basic/attention.py +105 -43
  12. diffsynth_engine/models/basic/transformer_helper.py +36 -2
  13. diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
  14. diffsynth_engine/models/flux/flux_controlnet.py +16 -30
  15. diffsynth_engine/models/flux/flux_dit.py +49 -62
  16. diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
  17. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  18. diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
  19. diffsynth_engine/models/flux/flux_vae.py +20 -2
  20. diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
  21. diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
  22. diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
  23. diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  24. diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
  25. diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
  26. diffsynth_engine/models/sd/sd_unet.py +1 -1
  27. diffsynth_engine/models/sd3/sd3_dit.py +1 -1
  28. diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
  29. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
  30. diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
  31. diffsynth_engine/models/vae/vae.py +1 -1
  32. diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
  33. diffsynth_engine/models/wan/wan_dit.py +65 -28
  34. diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
  35. diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
  36. diffsynth_engine/models/wan/wan_vae.py +2 -2
  37. diffsynth_engine/pipelines/base.py +73 -7
  38. diffsynth_engine/pipelines/flux_image.py +139 -120
  39. diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
  40. diffsynth_engine/pipelines/qwen_image.py +272 -87
  41. diffsynth_engine/pipelines/sdxl_image.py +1 -1
  42. diffsynth_engine/pipelines/utils.py +52 -0
  43. diffsynth_engine/pipelines/wan_s2v.py +25 -14
  44. diffsynth_engine/pipelines/wan_video.py +43 -19
  45. diffsynth_engine/tokenizers/base.py +6 -0
  46. diffsynth_engine/tokenizers/qwen2.py +12 -4
  47. diffsynth_engine/utils/constants.py +13 -12
  48. diffsynth_engine/utils/download.py +4 -2
  49. diffsynth_engine/utils/env.py +2 -0
  50. diffsynth_engine/utils/flag.py +6 -0
  51. diffsynth_engine/utils/loader.py +25 -6
  52. diffsynth_engine/utils/parallel.py +62 -29
  53. diffsynth_engine/utils/video.py +3 -1
  54. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
  55. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
  56. /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
  57. /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
  58. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
  59. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
  60. /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
  61. /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
  62. /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
  63. /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
  64. /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
  65. /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
  66. /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
  67. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
  68. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
  69. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
diffsynth_engine/pipelines/wan_s2v.py
@@ -239,7 +239,15 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):

         return ref_latents, motion_latents, motion_frames

-    def encode_pose(self, pose_video: List[Image.Image], pose_video_fps: int, num_clips: int, num_frames_per_clip: int, height: int, width: int):
+    def encode_pose(
+        self,
+        pose_video: List[Image.Image],
+        pose_video_fps: int,
+        num_clips: int,
+        num_frames_per_clip: int,
+        height: int,
+        width: int,
+    ):
         self.load_models_to_device(["vae"])
         max_num_pose_frames = num_frames_per_clip * num_clips
         pose_video = read_n_frames(pose_video, pose_video_fps, max_num_pose_frames, target_fps=self.config.fps)
@@ -386,6 +394,7 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         void_audio_input: torch.Tensor | None = None,
     ):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
+        attn_kwargs = self.get_attn_kwargs(latents)

         noise_pred = model(
             x=latents,
@@ -400,6 +409,7 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
             drop_motion_frames=drop_motion_frames,
             audio_mask=audio_mask,
             void_audio_input=void_audio_input,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred

@@ -466,7 +476,9 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
             dtype=self.dtype,
         ).to(self.device)
         if pose_video is not None:
-            pose_latents_all_clips = self.encode_pose(pose_video, pose_video_fps, num_clips, num_frames_per_clip, height, width)
+            pose_latents_all_clips = self.encode_pose(
+                pose_video, pose_video_fps, num_clips, num_frames_per_clip, height, width
+            )

         output_frames_all_clips = []
         for clip_idx in range(num_clips):
@@ -602,7 +614,9 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         return cls.from_state_dict(state_dicts, config)

     @classmethod
-    def from_state_dict(cls, state_dicts: WanS2VStateDicts, config: WanSpeech2VideoPipelineConfig) -> "WanSpeech2VideoPipeline":
+    def from_state_dict(
+        cls, state_dicts: WanS2VStateDicts, config: WanSpeech2VideoPipelineConfig
+    ) -> "WanSpeech2VideoPipeline":
         if config.parallelism > 1:
             pipe = ParallelWrapper(
                 cfg_degree=config.cfg_degree,
@@ -617,7 +631,9 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         return pipe

     @classmethod
-    def _from_state_dict(cls, state_dicts: WanS2VStateDicts, config: WanSpeech2VideoPipelineConfig) -> "WanSpeech2VideoPipeline":
+    def _from_state_dict(
+        cls, state_dicts: WanS2VStateDicts, config: WanSpeech2VideoPipelineConfig
+    ) -> "WanSpeech2VideoPipeline":
         # default params from model config
         vae_type = "wan2.1-vae"
         dit_type = "wan2.2-s2v-14b"
@@ -632,25 +648,20 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         init_device = "cpu" if config.offload_mode is not None else config.device
         tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=512, clean="whitespace")
         text_encoder = WanTextEncoder.from_state_dict(state_dicts.t5, device=init_device, dtype=config.t5_dtype)
-        vae = WanVideoVAE.from_state_dict(state_dicts.vae, config=vae_config, device=init_device, dtype=config.vae_dtype)
+        vae = WanVideoVAE.from_state_dict(
+            state_dicts.vae, config=vae_config, device=init_device, dtype=config.vae_dtype
+        )
         audio_encoder = Wav2Vec2Model.from_state_dict(
             state_dicts.audio_encoder, config=Wav2Vec2Config(), device=init_device, dtype=config.audio_encoder_dtype
         )

         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             dit = WanS2VDiT.from_state_dict(
                 state_dicts.model,
                 config=model_config,
-                device=init_device,
+                device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-                attn_kwargs=attn_kwargs,
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
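
Taken together, the wan_s2v.py hunks above move attention configuration from construction time to call time: the sparge kwargs dict is no longer baked into WanS2VDiT, the constructor only receives a use_vsa flag derived from config.dit_attn_impl, and predict_noise rebuilds attn_kwargs on every call via self.get_attn_kwargs(latents). A minimal sketch of what such per-call kwargs could contain, based only on the keys visible in the removed block (the helper name and its standalone form are assumptions, not the package's actual implementation):

def build_attn_kwargs(config) -> dict:
    # Hypothetical sketch; the field names come from the removed constructor-time dict.
    return {
        "attn_impl": config.dit_attn_impl,
        "sparge_smooth_k": config.sparge_smooth_k,
        "sparge_cdfthreshd": config.sparge_cdfthreshd,
        "sparge_simthreshd1": config.sparge_simthreshd1,
        "sparge_pvthreshd": config.sparge_pvthreshd,
    }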
diffsynth_engine/pipelines/wan_video.py
@@ -95,8 +95,14 @@ class WanLoRAConverter(LoRAStateDictConverter):
         return state_dict


+class WanLowNoiseLoRAConverter(WanLoRAConverter):
+    def convert(self, state_dict):
+        return {"dit2": super().convert(state_dict)["dit"]}
+
+
 class WanVideoPipeline(BasePipeline):
     lora_converter = WanLoRAConverter()
+    low_noise_lora_converter = WanLowNoiseLoRAConverter()

     def __init__(
         self,
@@ -133,7 +139,13 @@ class WanVideoPipeline(BasePipeline):
         self.image_encoder = image_encoder
         self.model_names = ["text_encoder", "dit", "dit2", "vae", "image_encoder"]

-    def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
+    def load_loras(
+        self,
+        lora_list: List[Tuple[str, float]],
+        fused: bool = True,
+        save_original_weight: bool = False,
+        lora_converter: Optional[WanLoRAConverter] = None,
+    ):
         assert self.config.tp_degree is None or self.config.tp_degree == 1, (
             "load LoRA is not allowed when tensor parallel is enabled; "
             "set tp_degree=None or tp_degree=1 during pipeline initialization"
@@ -142,10 +154,24 @@ class WanVideoPipeline(BasePipeline):
             "load fused LoRA is not allowed when fully sharded data parallel is enabled; "
             "either load LoRA with fused=False or set use_fsdp=False during pipeline initialization"
         )
-        super().load_loras(lora_list, fused, save_original_weight)
+        super().load_loras(lora_list, fused, save_original_weight, lora_converter)
+
+    def load_loras_low_noise(
+        self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False
+    ):
+        assert self.dit2 is not None, "low noise LoRA can only be applied to Wan2.2"
+        self.load_loras(lora_list, fused, save_original_weight, self.low_noise_lora_converter)
+
+    def load_loras_high_noise(
+        self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False
+    ):
+        assert self.dit2 is not None, "high noise LoRA can only be applied to Wan2.2"
+        self.load_loras(lora_list, fused, save_original_weight)

     def unload_loras(self):
         self.dit.unload_loras()
+        if self.dit2 is not None:
+            self.dit2.unload_loras()
         self.text_encoder.unload_loras()

     def get_default_fps(self) -> int:
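
With the converter hook in place, Wan2.2 pipelines that carry two DiT experts (dit for high noise, dit2 for low noise) can route a LoRA to the right module. A hedged usage sketch; pipeline construction and the LoRA paths are placeholders, and the list items follow the List[Tuple[str, float]] signature (path or identifier, scale):

# Hypothetical usage; construction of a Wan2.2 WanVideoPipeline is elided.
pipe = ...  # must expose both dit (high noise) and dit2 (low noise)

pipe.load_loras_high_noise([("/path/to/high_noise_lora.safetensors", 0.8)])
pipe.load_loras_low_noise([("/path/to/low_noise_lora.safetensors", 0.8)])
# The low-noise call goes through WanLowNoiseLoRAConverter, which rewrites the
# converted "dit" entry to "dit2" so the weights land on the second expert.

pipe.unload_loras()  # now also resets dit2 when it is present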
@@ -301,6 +327,7 @@ class WanVideoPipeline(BasePipeline):

     def predict_noise(self, model, latents, image_clip_feature, image_y, timestep, context):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
+        attn_kwargs = self.get_attn_kwargs(latents)

         noise_pred = model(
             x=latents,
@@ -308,6 +335,7 @@ class WanVideoPipeline(BasePipeline):
             context=context,
             clip_feature=image_clip_feature,
             y=image_y,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred

@@ -556,19 +584,12 @@ class WanVideoPipeline(BasePipeline):
         dit_state_dict = state_dicts.model

         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             dit = WanDiT.from_state_dict(
                 dit_state_dict,
                 config=dit_config,
-                device=init_device,
+                device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-                attn_kwargs=attn_kwargs,
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -578,9 +599,9 @@ class WanVideoPipeline(BasePipeline):
             dit2 = WanDiT.from_state_dict(
                 dit2_state_dict,
                 config=dit_config,
-                device=init_device,
+                device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-                attn_kwargs=attn_kwargs,
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit2)
@@ -618,19 +639,22 @@ class WanVideoPipeline(BasePipeline):
     @staticmethod
     def _get_dit_type(model_state_dict: Dict[str, torch.Tensor] | Dict[str, Dict[str, torch.Tensor]]) -> str:
         # determine wan dit type by model params
+        def has_any_key(*xs):
+            return any(x in model_state_dict for x in xs)
+
         dit_type = None
-        if "high_noise_model" in model_state_dict and "low_noise_model" in model_state_dict:
+        if has_any_key("high_noise_model"):
             if model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 36:
                 dit_type = "wan2.2-i2v-a14b"
             elif model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 16:
                 dit_type = "wan2.2-t2v-a14b"
         elif model_state_dict["patch_embedding.weight"].shape[1] == 48:
             dit_type = "wan2.2-ti2v-5b"
-        elif "img_emb.emb_pos" in model_state_dict:
+        elif has_any_key("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed"):
             dit_type = "wan2.1-flf2v-14b"
-        elif "img_emb.proj.0.weight" in model_state_dict:
+        elif has_any_key("img_emb.proj.0.weight", "condition_embedder.image_embedder.norm1"):
             dit_type = "wan2.1-i2v-14b"
-        elif "blocks.39.self_attn.norm_q.weight" in model_state_dict:
+        elif has_any_key("blocks.39.self_attn.norm_q.weight", "blocks.39.attn1.norm_q.weight"):
             dit_type = "wan2.1-t2v-14b"
         else:
             dit_type = "wan2.1-t2v-1.3b"
@@ -645,6 +669,6 @@ class WanVideoPipeline(BasePipeline):
         return vae_type

     def compile(self):
-        self.dit.compile_repeated_blocks(dynamic=True)
+        self.dit.compile_repeated_blocks()
         if self.dit2 is not None:
-            self.dit2.compile_repeated_blocks(dynamic=True)
+            self.dit2.compile_repeated_blocks()
diffsynth_engine/tokenizers/base.py
@@ -1,10 +1,16 @@
 # Modified from transformers.tokenization_utils_base
 from typing import Dict, List, Union, overload
+from enum import Enum


 TOKENIZER_CONFIG_FILE = "tokenizer_config.json"


+class PaddingStrategy(str, Enum):
+    LONGEST = "longest"
+    MAX_LENGTH = "max_length"
+
+
 class BaseTokenizer:
     SPECIAL_TOKENS_ATTRIBUTES = [
         "bos_token",
diffsynth_engine/tokenizers/qwen2.py
@@ -4,7 +4,7 @@ import torch
 from typing import Dict, List, Union, Optional
 from tokenizers import Tokenizer as TokenizerFast, AddedToken

-from diffsynth_engine.tokenizers.base import BaseTokenizer, TOKENIZER_CONFIG_FILE
+from diffsynth_engine.tokenizers.base import BaseTokenizer, PaddingStrategy, TOKENIZER_CONFIG_FILE


 VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
@@ -165,22 +165,28 @@ class Qwen2TokenizerFast(BaseTokenizer):
         texts: Union[str, List[str]],
         max_length: Optional[int] = None,
         padding_side: Optional[str] = None,
+        padding_strategy: Union[PaddingStrategy, str] = "longest",
         **kwargs,
     ) -> Dict[str, "torch.Tensor"]:
         """
         Tokenize text and prepare for model inputs.

         Args:
-            text (`str`, `List[str]`, *optional*):
+            texts (`str`, `List[str]`):
                 The sequence or batch of sequences to be encoded.

             max_length (`int`, *optional*):
-                Each encoded sequence will be truncated or padded to max_length.
+                Maximum length of the encoded sequences.

             padding_side (`str`, *optional*):
                 The side on which the padding should be applied. Should be selected between `"right"` and `"left"`.
                 Defaults to `"right"`.

+            padding_strategy (`PaddingStrategy`, `str`, *optional*):
+                If `"longest"`, will pad the sequences to the longest sequence in the batch.
+                If `"max_length"`, will pad the sequences to the `max_length` argument.
+                Defaults to `"longest"`.
+
         Returns:
             `Dict[str, "torch.Tensor"]`: tensor dict compatible with model_input_names.
         """
@@ -190,7 +196,9 @@ class Qwen2TokenizerFast(BaseTokenizer):

         batch_ids = self.batch_encode(texts)
         ids_lens = [len(ids_) for ids_ in batch_ids]
-        max_length = max_length if max_length is not None else min(max(ids_lens), self.model_max_length)
+        max_length = max_length if max_length is not None else self.model_max_length
+        if padding_strategy == PaddingStrategy.LONGEST:
+            max_length = min(max(ids_lens), max_length)
         padding_side = padding_side if padding_side is not None else self.padding_side

         encoded = torch.zeros(len(texts), max_length, dtype=torch.long)
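
A worked illustration of the padding logic above, separated from the tokenizer class (the helper below is a standalone sketch, not part of the package):

def resolve_pad_length(ids_lens, model_max_length, max_length=None, padding_strategy="longest"):
    # Mirrors the updated branch: an explicit max_length wins, otherwise model_max_length,
    # and the "longest" strategy then shrinks it to the longest sequence in the batch.
    length = max_length if max_length is not None else model_max_length
    if padding_strategy == "longest":
        length = min(max(ids_lens), length)
    return length

assert resolve_pad_length([5, 9, 7], model_max_length=512) == 9     # pad to longest in batch
assert resolve_pad_length([5, 9, 7], model_max_length=512, padding_strategy="max_length") == 512
assert resolve_pad_length([5, 9, 7], model_max_length=512, max_length=32, padding_strategy="max_length") == 32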
diffsynth_engine/utils/constants.py
@@ -27,18 +27,19 @@ SD3_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd3", "sd3_tex
 SDXL_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_text_encoder.json")
 SDXL_UNET_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_unet.json")

-WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-1.3b.json")
-WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-14b.json")
-WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-i2v-14b.json")
-WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-flf2v-14b.json")
-WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-ti2v-5b.json")
-WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-t2v-a14b.json")
-WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-i2v-a14b.json")
-WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-s2v-14b.json")
-
-WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1-vae.json")
-WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2-vae.json")
-WAN_VAE_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan-vae-keymap.json")
+WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_1.3b.json")
+WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_14b.json")
+WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_i2v_14b.json")
+WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_flf2v_14b.json")
+WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_ti2v_5b.json")
+WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_t2v_a14b.json")
+WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_i2v_a14b.json")
+WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_s2v_14b.json")
+WAN_DIT_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan_dit_keymap.json")
+
+WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1_vae.json")
+WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2_vae.json")
+WAN_VAE_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan_vae_keymap.json")

 QWEN_IMAGE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen2_5_vl_config.json")
 QWEN_IMAGE_VISION_CONFIG_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen2_5_vl_vision_config.json")
diffsynth_engine/utils/download.py
@@ -12,7 +12,7 @@ from modelscope import snapshot_download
 from modelscope.hub.api import HubApi
 from diffsynth_engine.utils import logging
 from diffsynth_engine.utils.lock import HeartbeatFileLock
-from diffsynth_engine.utils.env import DIFFSYNTH_FILELOCK_DIR, DIFFSYNTH_CACHE
+from diffsynth_engine.utils.env import DIFFSYNTH_FILELOCK_DIR, DIFFSYNTH_CACHE, MS_HUB_OFFLINE
 from diffsynth_engine.utils.constants import MB

 logger = logging.get_logger(__name__)
@@ -81,7 +81,9 @@ def fetch_modelscope_model(
         api.login(access_token)
     with HeartbeatFileLock(lock_file_path):
         directory = os.path.join(DIFFSYNTH_CACHE, "modelscope", model_id, revision if revision else "__version")
-        dirpath = snapshot_download(model_id, revision=revision, local_dir=directory, allow_patterns=path)
+        dirpath = snapshot_download(
+            model_id, revision=revision, local_dir=directory, allow_patterns=path, local_files_only=MS_HUB_OFFLINE
+        )

     if isinstance(path, str):
         path = glob.glob(os.path.join(dirpath, path))
diffsynth_engine/utils/env.py
@@ -8,3 +8,5 @@ DIFFSYNTH_CACHE = os.environ.get("DIFFSYNTH_CACHE", os.path.join(HOME, ".cache",
 DIFFSYNTH_FILELOCK_DIR = os.environ.get(
     "DIFFSYNTH_FILELOCK_DIR", os.path.join(HOME, ".cache", "diffsynth", "filelocks")
 )
+
+MS_HUB_OFFLINE = os.getenv("MS_HUB_OFFLINE", "0").lower() in ("1", "true", "yes")
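
A short usage sketch for the new flag (illustrative; the only behaviour shown in this diff is that fetch_modelscope_model forwards it as local_files_only to snapshot_download):

import os

# Any of "1", "true", "yes" (case-insensitive) enables offline mode. The flag is read
# once when diffsynth_engine.utils.env is imported, so set it before importing the package.
os.environ["MS_HUB_OFFLINE"] = "1"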
diffsynth_engine/utils/flag.py
@@ -44,3 +44,9 @@ if SPARGE_ATTN_AVAILABLE:
     logger.info("Sparge attention is available")
 else:
     logger.info("Sparge attention is not available")
+
+VIDEO_SPARSE_ATTN_AVAILABLE = importlib.util.find_spec("vsa") is not None
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    logger.info("Video sparse attention is available")
+else:
+    logger.info("Video sparse attention is not available")
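
Downstream code can gate the video-sparse-attention path on the new flag, in the same spirit as the existing Sparge availability check (a hedged sketch; the fallback message is illustrative):

from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE

if not VIDEO_SPARSE_ATTN_AVAILABLE:
    print("vsa package not found; fall back to a dense attention implementation")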
diffsynth_engine/utils/loader.py
@@ -9,12 +9,10 @@ try:

     use_fast_safetensors = True
 except ImportError:
-    from safetensors.torch import load_file as _load_file
-
     use_fast_safetensors = False


-def load_file(path: str | os.PathLike, device: str = "cpu"):
+def load_file(path: str | os.PathLike, device: str = "cpu", need_metadata: bool = False):
     if use_fast_safetensors:
         logger.info(f"FastSafetensors load model from {path}")
         start_time = time.time()
@@ -24,13 +22,34 @@ def load_file(path: str | os.PathLike, device: str = "cpu"):
             direct_io=(os.environ.get("FAST_SAFETENSORS_DIRECT_IO", "False").upper() == "TRUE"),
         )
         logger.info(f"FastSafetensors Load Model End. Time: {time.time() - start_time:.2f}s")
-        return {k: v.to(device) for k, v in result.items()}
+        state_dict = {k: v.to(device) for k, v in result.items()}
+
+        if need_metadata:
+            # FastSafetensors does not expose metadata directly; read it with standard safetensors
+            from safetensors import safe_open
+
+            with safe_open(str(path), framework="pt", device="cpu") as f:
+                metadata = f.metadata()
+            return state_dict, metadata
+        else:
+            return state_dict
     else:
         logger.info(f"Safetensors load model from {path}")
         start_time = time.time()
-        result = _load_file(path, device=device)
+
+        from safetensors import safe_open
+
+        with safe_open(path, framework="pt", device="cpu") as f:
+            state_dict = {k: f.get_tensor(k).to(device) for k in f.keys()}
+            if need_metadata:
+                metadata = f.metadata()
+
         logger.info(f"Safetensors Load Model End. Time: {time.time() - start_time:.2f}s")
-        return result
+
+        if need_metadata:
+            return state_dict, metadata
+        else:
+            return state_dict


 save_file = _save_file
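
Callers can now opt into the safetensors header metadata, in which case the return type becomes a tuple. A usage sketch (the file path is a placeholder):

from diffsynth_engine.utils.loader import load_file

# Default behaviour is unchanged: a plain state dict.
state_dict = load_file("/path/to/model.safetensors", device="cpu")

# With need_metadata=True the function returns (state_dict, metadata), where metadata
# is the safetensors header dict, or None if the file carries none.
state_dict, metadata = load_file("/path/to/model.safetensors", device="cpu", need_metadata=True)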
diffsynth_engine/utils/parallel.py
@@ -8,19 +8,17 @@ import torch.multiprocessing as mp
 import torch.distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import ShardingStrategy
-from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.parallel.style import ParallelStyle
 from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
 from contextlib import contextmanager
 from datetime import timedelta
 from functools import partial
-from typing import Dict, List, Union, Optional
+from typing import Dict, List, Set, Type, Union, Optional
 from queue import Empty

 import diffsynth_engine.models.basic.attention as attention_ops
-from diffsynth_engine.models import PreTrainedModel
-from diffsynth_engine.pipelines import BasePipeline
 from diffsynth_engine.utils.platform import empty_cache
 from diffsynth_engine.utils import logging

@@ -40,10 +38,14 @@ class ProcessGroupSingleton(Singleton):
     def __init__(self):
         self.CFG_GROUP: Optional[dist.ProcessGroup] = None
         self.SP_GROUP: Optional[dist.ProcessGroup] = None
+        self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
+        self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
         self.TP_GROUP: Optional[dist.ProcessGroup] = None

         self.CFG_RANKS: List[int] = []
         self.SP_RANKS: List[int] = []
+        self.SP_ULYSSUES_RANKS: List[int] = []
+        self.SP_RING_RANKS: List[int] = []
         self.TP_RANKS: List[int] = []


@@ -82,6 +84,38 @@ def get_sp_ranks():
     return PROCESS_GROUP.SP_RANKS


+def get_sp_ulysses_group():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP
+
+
+def get_sp_ulysses_world_size():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
+
+
+def get_sp_ulysses_rank():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
+
+
+def get_sp_ulysses_ranks():
+    return PROCESS_GROUP.SP_ULYSSUES_RANKS
+
+
+def get_sp_ring_group():
+    return PROCESS_GROUP.SP_RING_GROUP
+
+
+def get_sp_ring_world_size():
+    return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
+
+
+def get_sp_ring_rank():
+    return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
+
+
+def get_sp_ring_ranks():
+    return PROCESS_GROUP.SP_RING_RANKS
+
+
 def get_tp_group():
     return PROCESS_GROUP.TP_GROUP

@@ -127,23 +161,32 @@ def init_parallel_pgs(
     blocks = [list(range(world_size))]
     cfg_groups, cfg_blocks = make_parallel_groups(blocks, cfg_degree)
     for cfg_ranks in cfg_groups:
-        cfg_group = dist.new_group(cfg_ranks)
         if rank in cfg_ranks:
-            PROCESS_GROUP.CFG_GROUP = cfg_group
+            PROCESS_GROUP.CFG_GROUP = dist.new_group(cfg_ranks)
             PROCESS_GROUP.CFG_RANKS = cfg_ranks

     sp_groups, sp_blocks = make_parallel_groups(cfg_blocks, sp_degree)
     for sp_ranks in sp_groups:
-        group = dist.new_group(sp_ranks)
         if rank in sp_ranks:
-            PROCESS_GROUP.SP_GROUP = group
+            PROCESS_GROUP.SP_GROUP = dist.new_group(sp_ranks)
             PROCESS_GROUP.SP_RANKS = sp_ranks

+    sp_ulysses_groups, sp_ulysses_blocks = make_parallel_groups(cfg_blocks, sp_ulysses_degree)
+    for sp_ulysses_ranks in sp_ulysses_groups:
+        if rank in sp_ulysses_ranks:
+            PROCESS_GROUP.SP_ULYSSUES_GROUP = dist.new_group(sp_ulysses_ranks)
+            PROCESS_GROUP.SP_ULYSSUES_RANKS = sp_ulysses_ranks
+
+    sp_ring_groups, _ = make_parallel_groups(sp_ulysses_blocks, sp_ring_degree)
+    for sp_ring_ranks in sp_ring_groups:
+        if rank in sp_ring_ranks:
+            PROCESS_GROUP.SP_RING_GROUP = dist.new_group(sp_ring_ranks)
+            PROCESS_GROUP.SP_RING_RANKS = sp_ring_ranks
+
     tp_groups, _ = make_parallel_groups(sp_blocks, tp_degree)
     for tp_ranks in tp_groups:
-        group = dist.new_group(tp_ranks)
         if rank in tp_ranks:
-            PROCESS_GROUP.TP_GROUP = group
+            PROCESS_GROUP.TP_GROUP = dist.new_group(tp_ranks)
             PROCESS_GROUP.TP_RANKS = tp_ranks

     set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
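
The sequence-parallel group is now split further into a Ulysses group and a ring group, built inside the cfg blocks in that order. A small worked sketch of how the degrees are expected to compose (the values and the sp_degree relation are assumptions based on the nesting above, not taken from this diff):

cfg_degree, sp_ulysses_degree, sp_ring_degree, tp_degree = 2, 2, 2, 1
sp_degree = sp_ulysses_degree * sp_ring_degree   # assumed Ulysses x ring factorization
world_size = cfg_degree * sp_degree * tp_degree
assert world_size == 8   # eight ranks would be needed for this illustrative layout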
@@ -174,25 +217,14 @@ def to_device(data, device):
 def shard_model(
     module: nn.Module,
     device_id: int | torch.device,
+    wrap_module_cls: Set[Type[nn.Module]],
     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD,
-    wrap_module_names: Optional[List[str]] = None,
 ):
-    wrap_module_names = wrap_module_names or []
-
-    def wrap_fn(m):
-        for name in wrap_module_names:
-            submodule = getattr(module, name)
-            if isinstance(submodule, nn.ModuleList) and m in submodule:
-                return True
-            elif not isinstance(submodule, nn.ModuleList) and m is submodule:
-                return True
-        return False
-
     return FSDP(
         module,
         device_id=device_id,
         sharding_strategy=sharding_strategy,
-        auto_wrap_policy=partial(lambda_auto_wrap_policy, lambda_fn=wrap_fn),
+        auto_wrap_policy=partial(transformer_auto_wrap_policy, transformer_layer_cls=wrap_module_cls),
     )


@@ -266,14 +298,15 @@ def _worker_loop(
         world_size=world_size,
     )

-    def wrap_for_parallel(module: Union[PreTrainedModel, BasePipeline]):
-        if isinstance(module, BasePipeline):
-            for model_name in module.model_names:
-                if isinstance(submodule := getattr(module, model_name), PreTrainedModel):
+    def wrap_for_parallel(module):
+        if hasattr(module, "model_names"):
+            for model_name in getattr(module, "model_names"):
+                submodule = getattr(module, model_name)
+                if getattr(submodule, "_supports_parallelization", False):
                     setattr(module, model_name, wrap_for_parallel(submodule))
             return module

-        if not module._supports_parallelization:
+        if not getattr(module, "_supports_parallelization", False):
             return module

         if tp_degree > 1:
@@ -283,7 +316,7 @@ def _worker_loop(
                 parallelize_plan=module.get_tp_plan(),
             )
         elif use_fsdp:
-            module = shard_model(module, device_id=device, wrap_module_names=module.get_fsdp_modules())
+            module = shard_model(module, device_id=device, wrap_module_cls=module.get_fsdp_module_cls())
         return module

     module = None
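
FSDP wrapping is now keyed on module classes via transformer_auto_wrap_policy, and models are asked for get_fsdp_module_cls() instead of get_fsdp_modules(). A hedged sketch of the new call shape with a toy block class (ToyBlock and ToyModel are illustrative; a real call passes the classes returned by the model and requires an initialized process group):

import torch.nn as nn

class ToyBlock(nn.Module):  # stand-in for a model's repeated transformer block
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock() for _ in range(4)])

# sharded = shard_model(ToyModel(), device_id=0, wrap_module_cls={ToyBlock})
# Under transformer_auto_wrap_policy every ToyBlock instance becomes its own FSDP unit.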
diffsynth_engine/utils/video.py
@@ -41,7 +41,9 @@ def save_video(frames, save_path, fps=15):
     writer.write(frames, fps=fps, codec=codec)


-def read_n_frames(frames: List[Image.Image], original_fps: int, n_frames: int, target_fps: int = 16) -> List[Image.Image]:
+def read_n_frames(
+    frames: List[Image.Image], original_fps: int, n_frames: int, target_fps: int = 16
+) -> List[Image.Image]:
     num_frames = len(frames)
     interval = max(1, round(original_fps / target_fps))
     sampled_frames: List[Image.Image] = []
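
The reflowed signature does not change the sampling rule visible above: frames are kept every interval = max(1, round(original_fps / target_fps)) input frames. A tiny worked example of that arithmetic (values are illustrative):

original_fps, target_fps = 30, 16
interval = max(1, round(original_fps / target_fps))  # round(1.875) == 2
assert interval == 2  # keep every second frame, roughly 15 fps from a 30 fps clip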
{diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: diffsynth_engine
-Version: 0.5.1.dev4
+Version: 0.6.1.dev25
 Author: MuseAI x ModelScope
 Classifier: Programming Language :: Python :: 3
 Classifier: Operating System :: OS Independent