diffsynth-engine 0.5.1.dev4__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +12 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
- diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
- diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
- diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
- diffsynth_engine/configs/__init__.py +16 -1
- diffsynth_engine/configs/controlnet.py +13 -0
- diffsynth_engine/configs/pipeline.py +37 -11
- diffsynth_engine/models/base.py +1 -1
- diffsynth_engine/models/basic/attention.py +105 -43
- diffsynth_engine/models/basic/transformer_helper.py +36 -2
- diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
- diffsynth_engine/models/flux/flux_controlnet.py +16 -30
- diffsynth_engine/models/flux/flux_dit.py +49 -62
- diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
- diffsynth_engine/models/flux/flux_vae.py +20 -2
- diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
- diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
- diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
- diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
- diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
- diffsynth_engine/models/sd/sd_unet.py +1 -1
- diffsynth_engine/models/sd3/sd3_dit.py +1 -1
- diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
- diffsynth_engine/models/vae/vae.py +1 -1
- diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
- diffsynth_engine/models/wan/wan_dit.py +65 -28
- diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
- diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
- diffsynth_engine/models/wan/wan_vae.py +2 -2
- diffsynth_engine/pipelines/base.py +73 -7
- diffsynth_engine/pipelines/flux_image.py +139 -120
- diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
- diffsynth_engine/pipelines/qwen_image.py +272 -87
- diffsynth_engine/pipelines/sdxl_image.py +1 -1
- diffsynth_engine/pipelines/utils.py +52 -0
- diffsynth_engine/pipelines/wan_s2v.py +25 -14
- diffsynth_engine/pipelines/wan_video.py +43 -19
- diffsynth_engine/tokenizers/base.py +6 -0
- diffsynth_engine/tokenizers/qwen2.py +12 -4
- diffsynth_engine/utils/constants.py +13 -12
- diffsynth_engine/utils/download.py +4 -2
- diffsynth_engine/utils/env.py +2 -0
- diffsynth_engine/utils/flag.py +6 -0
- diffsynth_engine/utils/loader.py +25 -6
- diffsynth_engine/utils/parallel.py +62 -29
- diffsynth_engine/utils/video.py +3 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0

diffsynth_engine/pipelines/wan_s2v.py
CHANGED
@@ -239,7 +239,15 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
 
         return ref_latents, motion_latents, motion_frames
 
-    def encode_pose(
+    def encode_pose(
+        self,
+        pose_video: List[Image.Image],
+        pose_video_fps: int,
+        num_clips: int,
+        num_frames_per_clip: int,
+        height: int,
+        width: int,
+    ):
         self.load_models_to_device(["vae"])
         max_num_pose_frames = num_frames_per_clip * num_clips
         pose_video = read_n_frames(pose_video, pose_video_fps, max_num_pose_frames, target_fps=self.config.fps)

@@ -386,6 +394,7 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         void_audio_input: torch.Tensor | None = None,
     ):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
+        attn_kwargs = self.get_attn_kwargs(latents)
 
         noise_pred = model(
             x=latents,

@@ -400,6 +409,7 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
             drop_motion_frames=drop_motion_frames,
             audio_mask=audio_mask,
             void_audio_input=void_audio_input,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred
 

@@ -466,7 +476,9 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
             dtype=self.dtype,
         ).to(self.device)
         if pose_video is not None:
-            pose_latents_all_clips = self.encode_pose(
+            pose_latents_all_clips = self.encode_pose(
+                pose_video, pose_video_fps, num_clips, num_frames_per_clip, height, width
+            )
 
         output_frames_all_clips = []
         for clip_idx in range(num_clips):

@@ -602,7 +614,9 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         return cls.from_state_dict(state_dicts, config)
 
     @classmethod
-    def from_state_dict(
+    def from_state_dict(
+        cls, state_dicts: WanS2VStateDicts, config: WanSpeech2VideoPipelineConfig
+    ) -> "WanSpeech2VideoPipeline":
         if config.parallelism > 1:
             pipe = ParallelWrapper(
                 cfg_degree=config.cfg_degree,

@@ -617,7 +631,9 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         return pipe
 
     @classmethod
-    def _from_state_dict(
+    def _from_state_dict(
+        cls, state_dicts: WanS2VStateDicts, config: WanSpeech2VideoPipelineConfig
+    ) -> "WanSpeech2VideoPipeline":
         # default params from model config
         vae_type = "wan2.1-vae"
         dit_type = "wan2.2-s2v-14b"

@@ -632,25 +648,20 @@ class WanSpeech2VideoPipeline(WanVideoPipeline):
         init_device = "cpu" if config.offload_mode is not None else config.device
         tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=512, clean="whitespace")
         text_encoder = WanTextEncoder.from_state_dict(state_dicts.t5, device=init_device, dtype=config.t5_dtype)
-        vae = WanVideoVAE.from_state_dict(
+        vae = WanVideoVAE.from_state_dict(
+            state_dicts.vae, config=vae_config, device=init_device, dtype=config.vae_dtype
+        )
         audio_encoder = Wav2Vec2Model.from_state_dict(
             state_dicts.audio_encoder, config=Wav2Vec2Config(), device=init_device, dtype=config.audio_encoder_dtype
         )
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             dit = WanS2VDiT.from_state_dict(
                 state_dicts.model,
                 config=model_config,
-                device=init_device,
+                device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-                attn_kwargs=attn_kwargs,
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
diffsynth_engine/pipelines/wan_video.py
CHANGED
@@ -95,8 +95,14 @@ class WanLoRAConverter(LoRAStateDictConverter):
         return state_dict
 
 
+class WanLowNoiseLoRAConverter(WanLoRAConverter):
+    def convert(self, state_dict):
+        return {"dit2": super().convert(state_dict)["dit"]}
+
+
 class WanVideoPipeline(BasePipeline):
     lora_converter = WanLoRAConverter()
+    low_noise_lora_converter = WanLowNoiseLoRAConverter()
 
     def __init__(
         self,

@@ -133,7 +139,13 @@ class WanVideoPipeline(BasePipeline):
         self.image_encoder = image_encoder
         self.model_names = ["text_encoder", "dit", "dit2", "vae", "image_encoder"]
 
-    def load_loras(
+    def load_loras(
+        self,
+        lora_list: List[Tuple[str, float]],
+        fused: bool = True,
+        save_original_weight: bool = False,
+        lora_converter: Optional[WanLoRAConverter] = None,
+    ):
         assert self.config.tp_degree is None or self.config.tp_degree == 1, (
             "load LoRA is not allowed when tensor parallel is enabled; "
             "set tp_degree=None or tp_degree=1 during pipeline initialization"

@@ -142,10 +154,24 @@ class WanVideoPipeline(BasePipeline):
             "load fused LoRA is not allowed when fully sharded data parallel is enabled; "
             "either load LoRA with fused=False or set use_fsdp=False during pipeline initialization"
         )
-        super().load_loras(lora_list, fused, save_original_weight)
+        super().load_loras(lora_list, fused, save_original_weight, lora_converter)
+
+    def load_loras_low_noise(
+        self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False
+    ):
+        assert self.dit2 is not None, "low noise LoRA can only be applied to Wan2.2"
+        self.load_loras(lora_list, fused, save_original_weight, self.low_noise_lora_converter)
+
+    def load_loras_high_noise(
+        self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False
+    ):
+        assert self.dit2 is not None, "high noise LoRA can only be applied to Wan2.2"
+        self.load_loras(lora_list, fused, save_original_weight)
 
     def unload_loras(self):
         self.dit.unload_loras()
+        if self.dit2 is not None:
+            self.dit2.unload_loras()
         self.text_encoder.unload_loras()
 
     def get_default_fps(self) -> int:
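The new LoRA helpers above split loading between the two experts of a Wan2.2 pipeline: `load_loras_high_noise` targets `dit` directly, while `load_loras_low_noise` reroutes the converted weights to `dit2` via `WanLowNoiseLoRAConverter`. A minimal usage sketch; the checkpoint paths and scales are placeholders, and `pipe` is assumed to be a Wan2.2 `WanVideoPipeline` (so `pipe.dit2` is not None):

```python
# Hypothetical paths and scales; both calls assert that pipe.dit2 exists (Wan2.2 only).
pipe.load_loras_high_noise([("/path/to/high_noise_lora.safetensors", 1.0)])
pipe.load_loras_low_noise([("/path/to/low_noise_lora.safetensors", 1.0)])

# unload_loras() now also clears LoRA weights fused into the second DiT.
pipe.unload_loras()
```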
@@ -301,6 +327,7 @@ class WanVideoPipeline(BasePipeline):
 
     def predict_noise(self, model, latents, image_clip_feature, image_y, timestep, context):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
+        attn_kwargs = self.get_attn_kwargs(latents)
 
         noise_pred = model(
             x=latents,

@@ -308,6 +335,7 @@ class WanVideoPipeline(BasePipeline):
             context=context,
             clip_feature=image_clip_feature,
             y=image_y,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred
 

@@ -556,19 +584,12 @@ class WanVideoPipeline(BasePipeline):
         dit_state_dict = state_dicts.model
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             dit = WanDiT.from_state_dict(
                 dit_state_dict,
                 config=dit_config,
-                device=init_device,
+                device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-                attn_kwargs=attn_kwargs,
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)

@@ -578,9 +599,9 @@ class WanVideoPipeline(BasePipeline):
             dit2 = WanDiT.from_state_dict(
                 dit2_state_dict,
                 config=dit_config,
-                device=init_device,
+                device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
-                attn_kwargs=attn_kwargs,
+                use_vsa=(config.dit_attn_impl.value == "vsa"),
             )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit2)

@@ -618,19 +639,22 @@ class WanVideoPipeline(BasePipeline):
     @staticmethod
     def _get_dit_type(model_state_dict: Dict[str, torch.Tensor] | Dict[str, Dict[str, torch.Tensor]]) -> str:
         # determine wan dit type by model params
+        def has_any_key(*xs):
+            return any(x in model_state_dict for x in xs)
+
         dit_type = None
-        if "high_noise_model" in model_state_dict:
+        if has_any_key("high_noise_model"):
             if model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 36:
                 dit_type = "wan2.2-i2v-a14b"
             elif model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 16:
                 dit_type = "wan2.2-t2v-a14b"
         elif model_state_dict["patch_embedding.weight"].shape[1] == 48:
             dit_type = "wan2.2-ti2v-5b"
-        elif "img_emb.emb_pos" in model_state_dict:
+        elif has_any_key("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed"):
             dit_type = "wan2.1-flf2v-14b"
-        elif "img_emb.proj.0.weight" in model_state_dict:
+        elif has_any_key("img_emb.proj.0.weight", "condition_embedder.image_embedder.norm1"):
             dit_type = "wan2.1-i2v-14b"
-        elif "blocks.39.self_attn.norm_q.weight" in model_state_dict:
+        elif has_any_key("blocks.39.self_attn.norm_q.weight", "blocks.39.attn1.norm_q.weight"):
             dit_type = "wan2.1-t2v-14b"
         else:
             dit_type = "wan2.1-t2v-1.3b"
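The `has_any_key` helper above lets `_get_dit_type` accept either the native Wan key names or the alternative `condition_embedder.*` / `attn1.*` layout when classifying a checkpoint. A small illustration of the lookup (dummy tensor, not a real checkpoint):

```python
import torch

state_dict = {"condition_embedder.image_embedder.pos_embed": torch.zeros(1)}  # dummy entry
keys = ("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed")
print(any(k in state_dict for k in keys))  # True, so this checkpoint is treated as "wan2.1-flf2v-14b"
```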
@@ -645,6 +669,6 @@ class WanVideoPipeline(BasePipeline):
         return vae_type
 
     def compile(self):
-        self.dit.compile_repeated_blocks(
+        self.dit.compile_repeated_blocks()
         if self.dit2 is not None:
-            self.dit2.compile_repeated_blocks(
+            self.dit2.compile_repeated_blocks()
diffsynth_engine/tokenizers/base.py
CHANGED
@@ -1,10 +1,16 @@
 # Modified from transformers.tokenization_utils_base
 from typing import Dict, List, Union, overload
+from enum import Enum
 
 
 TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
 
 
+class PaddingStrategy(str, Enum):
+    LONGEST = "longest"
+    MAX_LENGTH = "max_length"
+
+
 class BaseTokenizer:
     SPECIAL_TOKENS_ATTRIBUTES = [
         "bos_token",
diffsynth_engine/tokenizers/qwen2.py
CHANGED
@@ -4,7 +4,7 @@ import torch
 from typing import Dict, List, Union, Optional
 from tokenizers import Tokenizer as TokenizerFast, AddedToken
 
-from diffsynth_engine.tokenizers.base import BaseTokenizer, TOKENIZER_CONFIG_FILE
+from diffsynth_engine.tokenizers.base import BaseTokenizer, PaddingStrategy, TOKENIZER_CONFIG_FILE
 
 
 VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}

@@ -165,22 +165,28 @@ class Qwen2TokenizerFast(BaseTokenizer):
         texts: Union[str, List[str]],
         max_length: Optional[int] = None,
         padding_side: Optional[str] = None,
+        padding_strategy: Union[PaddingStrategy, str] = "longest",
         **kwargs,
     ) -> Dict[str, "torch.Tensor"]:
         """
         Tokenize text and prepare for model inputs.
 
         Args:
-
+            texts (`str`, `List[str]`):
                 The sequence or batch of sequences to be encoded.
 
             max_length (`int`, *optional*):
-
+                Maximum length of the encoded sequences.
 
             padding_side (`str`, *optional*):
                 The side on which the padding should be applied. Should be selected between `"right"` and `"left"`.
                 Defaults to `"right"`.
 
+            padding_strategy (`PaddingStrategy`, `str`, *optional*):
+                If `"longest"`, will pad the sequences to the longest sequence in the batch.
+                If `"max_length"`, will pad the sequences to the `max_length` argument.
+                Defaults to `"longest"`.
+
         Returns:
             `Dict[str, "torch.Tensor"]`: tensor dict compatible with model_input_names.
         """

@@ -190,7 +196,9 @@ class Qwen2TokenizerFast(BaseTokenizer):
 
         batch_ids = self.batch_encode(texts)
         ids_lens = [len(ids_) for ids_ in batch_ids]
-        max_length = max_length if max_length is not None else
+        max_length = max_length if max_length is not None else self.model_max_length
+        if padding_strategy == PaddingStrategy.LONGEST:
+            max_length = min(max(ids_lens), max_length)
         padding_side = padding_side if padding_side is not None else self.padding_side
 
         encoded = torch.zeros(len(texts), max_length, dtype=torch.long)
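With `padding_strategy`, batch encoding either pads to the longest sequence in the batch, capped at `max_length` (the default), or always pads out to `max_length`. A hedged usage sketch; tokenizer construction is omitted and `tokenizer` is assumed to be a `Qwen2TokenizerFast` instance:

```python
# Default "longest": the tensor width is min(longest sequence in the batch, max_length).
batch = tokenizer(["a prompt", "a much longer prompt"], max_length=64)

# "max_length": the tensor width is always exactly max_length.
fixed = tokenizer(["a prompt"], max_length=64, padding_strategy="max_length")

# Both return a dict of torch.Tensor keyed by the tokenizer's model_input_names.
```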
diffsynth_engine/utils/constants.py
CHANGED
@@ -27,18 +27,19 @@ SD3_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd3", "sd3_tex
 SDXL_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_text_encoder.json")
 SDXL_UNET_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_unet.json")
 
-WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-1.3b.json")
-WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-14b.json")
-WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-i2v-14b.json")
-WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-flf2v-14b.json")
-WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-ti2v-5b.json")
-WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-t2v-a14b.json")
-WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-i2v-a14b.json")
-WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-s2v-14b.json")
-
-
-
-
+WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_1.3b.json")
+WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_14b.json")
+WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_i2v_14b.json")
+WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_flf2v_14b.json")
+WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_ti2v_5b.json")
+WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_t2v_a14b.json")
+WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_i2v_a14b.json")
+WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_s2v_14b.json")
+WAN_DIT_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan_dit_keymap.json")
+
+WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1_vae.json")
+WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2_vae.json")
+WAN_VAE_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan_vae_keymap.json")
 
 QWEN_IMAGE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen2_5_vl_config.json")
 QWEN_IMAGE_VISION_CONFIG_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen2_5_vl_vision_config.json")
diffsynth_engine/utils/download.py
CHANGED
@@ -12,7 +12,7 @@ from modelscope import snapshot_download
 from modelscope.hub.api import HubApi
 from diffsynth_engine.utils import logging
 from diffsynth_engine.utils.lock import HeartbeatFileLock
-from diffsynth_engine.utils.env import DIFFSYNTH_FILELOCK_DIR, DIFFSYNTH_CACHE
+from diffsynth_engine.utils.env import DIFFSYNTH_FILELOCK_DIR, DIFFSYNTH_CACHE, MS_HUB_OFFLINE
 from diffsynth_engine.utils.constants import MB
 
 logger = logging.get_logger(__name__)

@@ -81,7 +81,9 @@ def fetch_modelscope_model(
         api.login(access_token)
     with HeartbeatFileLock(lock_file_path):
         directory = os.path.join(DIFFSYNTH_CACHE, "modelscope", model_id, revision if revision else "__version")
-        dirpath = snapshot_download(
+        dirpath = snapshot_download(
+            model_id, revision=revision, local_dir=directory, allow_patterns=path, local_files_only=MS_HUB_OFFLINE
+        )
 
     if isinstance(path, str):
         path = glob.glob(os.path.join(dirpath, path))
diffsynth_engine/utils/env.py
CHANGED
@@ -8,3 +8,5 @@ DIFFSYNTH_CACHE = os.environ.get("DIFFSYNTH_CACHE", os.path.join(HOME, ".cache",
 DIFFSYNTH_FILELOCK_DIR = os.environ.get(
     "DIFFSYNTH_FILELOCK_DIR", os.path.join(HOME, ".cache", "diffsynth", "filelocks")
 )
+
+MS_HUB_OFFLINE = os.getenv("MS_HUB_OFFLINE", "0").lower() in ("1", "true", "yes")
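`MS_HUB_OFFLINE` is evaluated once when `diffsynth_engine.utils.env` is imported, and the `download.py` hunk above forwards it to `snapshot_download` as `local_files_only`. A minimal sketch of enabling it from Python (exporting the variable in the shell before launching works the same way):

```python
import os

# Must be set before diffsynth_engine.utils.env is first imported;
# "1", "true" and "yes" all enable offline mode.
os.environ["MS_HUB_OFFLINE"] = "1"

from diffsynth_engine.utils.env import MS_HUB_OFFLINE
assert MS_HUB_OFFLINE is True
```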
diffsynth_engine/utils/flag.py
CHANGED
@@ -44,3 +44,9 @@ if SPARGE_ATTN_AVAILABLE:
     logger.info("Sparge attention is available")
 else:
     logger.info("Sparge attention is not available")
+
+VIDEO_SPARSE_ATTN_AVAILABLE = importlib.util.find_spec("vsa") is not None
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    logger.info("Video sparse attention is available")
+else:
+    logger.info("Video sparse attention is not available")
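`VIDEO_SPARSE_ATTN_AVAILABLE` mirrors the existing sparge-attention probe and only checks that a `vsa` module can be imported. A hedged sketch of gating an attention choice on it; the fallback value is a placeholder, not an API from this package:

```python
from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE

# Pick video sparse attention only when the optional vsa package is installed.
attn_impl = "vsa" if VIDEO_SPARSE_ATTN_AVAILABLE else None  # None here stands for "use the default"
```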
diffsynth_engine/utils/loader.py
CHANGED
@@ -9,12 +9,10 @@ try:
 
     use_fast_safetensors = True
 except ImportError:
-    from safetensors.torch import load_file as _load_file
-
     use_fast_safetensors = False
 
 
-def load_file(path: str | os.PathLike, device: str = "cpu"):
+def load_file(path: str | os.PathLike, device: str = "cpu", need_metadata: bool = False):
     if use_fast_safetensors:
         logger.info(f"FastSafetensors load model from {path}")
         start_time = time.time()

@@ -24,13 +22,34 @@ def load_file(path: str | os.PathLike, device: str = "cpu"):
             direct_io=(os.environ.get("FAST_SAFETENSORS_DIRECT_IO", "False").upper() == "TRUE"),
         )
         logger.info(f"FastSafetensors Load Model End. Time: {time.time() - start_time:.2f}s")
-
+        state_dict = {k: v.to(device) for k, v in result.items()}
+
+        if need_metadata:
+            # FastSafetensors does not expose metadata directly; read it with standard safetensors
+            from safetensors import safe_open
+
+            with safe_open(str(path), framework="pt", device="cpu") as f:
+                metadata = f.metadata()
+            return state_dict, metadata
+        else:
+            return state_dict
     else:
         logger.info(f"Safetensors load model from {path}")
         start_time = time.time()
-
+
+        from safetensors import safe_open
+
+        with safe_open(path, framework="pt", device="cpu") as f:
+            state_dict = {k: f.get_tensor(k).to(device) for k in f.keys()}
+            if need_metadata:
+                metadata = f.metadata()
+
         logger.info(f"Safetensors Load Model End. Time: {time.time() - start_time:.2f}s")
-
+
+        if need_metadata:
+            return state_dict, metadata
+        else:
+            return state_dict
 
 
 save_file = _save_file
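`load_file` can now also return the safetensors header metadata. A usage sketch (the checkpoint path is a placeholder); with `need_metadata=True` the call returns a `(state_dict, metadata)` tuple instead of just the state dict:

```python
from diffsynth_engine.utils.loader import load_file

state_dict = load_file("/path/to/model.safetensors")  # unchanged: dict of tensors
state_dict, metadata = load_file("/path/to/model.safetensors", need_metadata=True)
print(metadata)  # header metadata dict, or None if the file carries none
```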
diffsynth_engine/utils/parallel.py
CHANGED
@@ -8,19 +8,17 @@ import torch.multiprocessing as mp
 import torch.distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import ShardingStrategy
-from torch.distributed.fsdp.wrap import
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.parallel.style import ParallelStyle
 from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
 from contextlib import contextmanager
 from datetime import timedelta
 from functools import partial
-from typing import Dict, List, Union, Optional
+from typing import Dict, List, Set, Type, Union, Optional
 from queue import Empty
 
 import diffsynth_engine.models.basic.attention as attention_ops
-from diffsynth_engine.models import PreTrainedModel
-from diffsynth_engine.pipelines import BasePipeline
 from diffsynth_engine.utils.platform import empty_cache
 from diffsynth_engine.utils import logging
 

@@ -40,10 +38,14 @@ class ProcessGroupSingleton(Singleton):
     def __init__(self):
         self.CFG_GROUP: Optional[dist.ProcessGroup] = None
         self.SP_GROUP: Optional[dist.ProcessGroup] = None
+        self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
+        self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
         self.TP_GROUP: Optional[dist.ProcessGroup] = None
 
         self.CFG_RANKS: List[int] = []
         self.SP_RANKS: List[int] = []
+        self.SP_ULYSSUES_RANKS: List[int] = []
+        self.SP_RING_RANKS: List[int] = []
         self.TP_RANKS: List[int] = []
 
 

@@ -82,6 +84,38 @@ def get_sp_ranks():
     return PROCESS_GROUP.SP_RANKS
 
 
+def get_sp_ulysses_group():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP
+
+
+def get_sp_ulysses_world_size():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
+
+
+def get_sp_ulysses_rank():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
+
+
+def get_sp_ulysses_ranks():
+    return PROCESS_GROUP.SP_ULYSSUES_RANKS
+
+
+def get_sp_ring_group():
+    return PROCESS_GROUP.SP_RING_GROUP
+
+
+def get_sp_ring_world_size():
+    return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
+
+
+def get_sp_ring_rank():
+    return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
+
+
+def get_sp_ring_ranks():
+    return PROCESS_GROUP.SP_RING_RANKS
+
+
 def get_tp_group():
     return PROCESS_GROUP.TP_GROUP
 

@@ -127,23 +161,32 @@ def init_parallel_pgs(
     blocks = [list(range(world_size))]
     cfg_groups, cfg_blocks = make_parallel_groups(blocks, cfg_degree)
     for cfg_ranks in cfg_groups:
-        cfg_group = dist.new_group(cfg_ranks)
         if rank in cfg_ranks:
-            PROCESS_GROUP.CFG_GROUP = cfg_group
+            PROCESS_GROUP.CFG_GROUP = dist.new_group(cfg_ranks)
             PROCESS_GROUP.CFG_RANKS = cfg_ranks
 
     sp_groups, sp_blocks = make_parallel_groups(cfg_blocks, sp_degree)
     for sp_ranks in sp_groups:
-        group = dist.new_group(sp_ranks)
         if rank in sp_ranks:
-            PROCESS_GROUP.SP_GROUP = group
+            PROCESS_GROUP.SP_GROUP = dist.new_group(sp_ranks)
             PROCESS_GROUP.SP_RANKS = sp_ranks
 
+    sp_ulysses_groups, sp_ulysses_blocks = make_parallel_groups(cfg_blocks, sp_ulysses_degree)
+    for sp_ulysses_ranks in sp_ulysses_groups:
+        if rank in sp_ulysses_ranks:
+            PROCESS_GROUP.SP_ULYSSUES_GROUP = dist.new_group(sp_ulysses_ranks)
+            PROCESS_GROUP.SP_ULYSSUES_RANKS = sp_ulysses_ranks
+
+    sp_ring_groups, _ = make_parallel_groups(sp_ulysses_blocks, sp_ring_degree)
+    for sp_ring_ranks in sp_ring_groups:
+        if rank in sp_ring_ranks:
+            PROCESS_GROUP.SP_RING_GROUP = dist.new_group(sp_ring_ranks)
+            PROCESS_GROUP.SP_RING_RANKS = sp_ring_ranks
+
     tp_groups, _ = make_parallel_groups(sp_blocks, tp_degree)
     for tp_ranks in tp_groups:
-        group = dist.new_group(tp_ranks)
         if rank in tp_ranks:
-            PROCESS_GROUP.TP_GROUP = group
+            PROCESS_GROUP.TP_GROUP = dist.new_group(tp_ranks)
             PROCESS_GROUP.TP_RANKS = tp_ranks
 
     set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
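Reading `init_parallel_pgs` above: CFG groups partition the world, Ulysses and ring groups subdivide each CFG block for sequence parallelism, and TP groups subdivide each SP block, so the configured degrees have to multiply out to the world size (assuming `sp_degree == sp_ulysses_degree * sp_ring_degree`, which the grouping implies). A worked sketch for an 8-GPU launch:

```python
# Hypothetical 8-GPU layout (values are illustrative, not package defaults).
world_size = 8
cfg_degree = 2          # two CFG branches, e.g. ranks 0-3 vs ranks 4-7
sp_ulysses_degree = 2   # Ulysses groups inside each CFG block
sp_ring_degree = 2      # ring groups inside each Ulysses block; sp_degree = 2 * 2 = 4
tp_degree = 1           # no tensor parallelism in this example
assert cfg_degree * sp_ulysses_degree * sp_ring_degree * tp_degree == world_size
```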
@@ -174,25 +217,14 @@ def to_device(data, device):
 def shard_model(
     module: nn.Module,
     device_id: int | torch.device,
+    wrap_module_cls: Set[Type[nn.Module]],
     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD,
-    wrap_module_names: Optional[List[str]] = None,
 ):
-    wrap_module_names = wrap_module_names or []
-
-    def wrap_fn(m):
-        for name in wrap_module_names:
-            submodule = getattr(module, name)
-            if isinstance(submodule, nn.ModuleList) and m in submodule:
-                return True
-            elif not isinstance(submodule, nn.ModuleList) and m is submodule:
-                return True
-        return False
-
     return FSDP(
         module,
         device_id=device_id,
         sharding_strategy=sharding_strategy,
-        auto_wrap_policy=partial(
+        auto_wrap_policy=partial(transformer_auto_wrap_policy, transformer_layer_cls=wrap_module_cls),
     )
 
 

@@ -266,14 +298,15 @@ def _worker_loop(
         world_size=world_size,
     )
 
-    def wrap_for_parallel(module
-        if
-            for model_name in module
-
+    def wrap_for_parallel(module):
+        if hasattr(module, "model_names"):
+            for model_name in getattr(module, "model_names"):
+                submodule = getattr(module, model_name)
+                if getattr(submodule, "_supports_parallelization", False):
                     setattr(module, model_name, wrap_for_parallel(submodule))
             return module
 
-        if not module
+        if not getattr(module, "_supports_parallelization", False):
             return module
 
         if tp_degree > 1:

@@ -283,7 +316,7 @@ def _worker_loop(
                 parallelize_plan=module.get_tp_plan(),
             )
         elif use_fsdp:
-            module = shard_model(module, device_id=device,
+            module = shard_model(module, device_id=device, wrap_module_cls=module.get_fsdp_module_cls())
         return module
 
     module = None
diffsynth_engine/utils/video.py
CHANGED
@@ -41,7 +41,9 @@ def save_video(frames, save_path, fps=15):
     writer.write(frames, fps=fps, codec=codec)
 
 
-def read_n_frames(
+def read_n_frames(
+    frames: List[Image.Image], original_fps: int, n_frames: int, target_fps: int = 16
+) -> List[Image.Image]:
     num_frames = len(frames)
     interval = max(1, round(original_fps / target_fps))
     sampled_frames: List[Image.Image] = []