diffsynth-engine 0.6.1.dev28__py3-none-any.whl → 0.6.1.dev30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/models/basic/video_sparse_attention.py +15 -3
- diffsynth_engine/pipelines/base.py +30 -9
- diffsynth_engine/pipelines/wan_video.py +25 -1
- diffsynth_engine/utils/parallel.py +23 -107
- diffsynth_engine/utils/process_group.py +149 -0
- {diffsynth_engine-0.6.1.dev28.dist-info → diffsynth_engine-0.6.1.dev30.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.6.1.dev28.dist-info → diffsynth_engine-0.6.1.dev30.dist-info}/RECORD +10 -9
- {diffsynth_engine-0.6.1.dev28.dist-info → diffsynth_engine-0.6.1.dev30.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.6.1.dev28.dist-info → diffsynth_engine-0.6.1.dev30.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev28.dist-info → diffsynth_engine-0.6.1.dev30.dist-info}/top_level.txt +0 -0
diffsynth_engine/models/basic/video_sparse_attention.py

@@ -3,10 +3,15 @@ import math
 import functools
 
 from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE
-from diffsynth_engine.utils.
+from diffsynth_engine.utils.process_group import get_sp_ulysses_group, get_sp_ring_world_size
 
+
+vsa_core = None
 if VIDEO_SPARSE_ATTN_AVAILABLE:
-
+    try:
+        from vsa import video_sparse_attn as vsa_core
+    except Exception:
+        vsa_core = None
 
 VSA_TILE_SIZE = (4, 4, 4)
 
@@ -171,6 +176,12 @@ def video_sparse_attn(
     variable_block_sizes: torch.LongTensor,
     non_pad_index: torch.LongTensor,
 ):
+    if vsa_core is None:
+        raise RuntimeError(
+            "Video sparse attention (VSA) is not available. "
+            "Please install the 'vsa' package and ensure all its dependencies (including pytest) are installed."
+        )
+
     q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
     k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
     v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
@@ -212,7 +223,8 @@ def distributed_video_sparse_attn(
 ):
     from yunchang.comm.all_to_all import SeqAllToAll4D
 
-
+    ring_world_size = get_sp_ring_world_size()
+    assert ring_world_size == 1, "distributed video sparse attention requires ring degree to be 1"
    sp_ulysses_group = get_sp_ulysses_group()
 
     q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)
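Note: the hunks above make the `vsa` import lazy and fault-tolerant: `vsa_core` stays `None` when the package is missing or fails to import, and `video_sparse_attn` raises a clear RuntimeError at call time instead of breaking the module import. Below is a minimal sketch of the same optional-dependency pattern; the package name `fast_attn` and the function `fast_attention` are illustrative only and not part of diffsynth-engine.

    # Hypothetical optional dependency; "fast_attn" is an illustrative name.
    fast_attn_core = None
    try:
        from fast_attn import attention as fast_attn_core
    except Exception:
        fast_attn_core = None

    def fast_attention(q, k, v):
        # Fail at call time with an actionable message, not at import time.
        if fast_attn_core is None:
            raise RuntimeError("fast_attn is not installed; install it to use fast_attention().")
        return fast_attn_core(q, k, v)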
diffsynth_engine/pipelines/base.py

@@ -74,9 +74,9 @@ class BasePipeline:
         component.load_state_dict(state_dict, assign=True)
         component.to(device=device, dtype=dtype, non_blocking=True)
 
-    def
+    def _load_lora_state_dicts(
         self,
-
+        lora_state_dict_list: List[Tuple[Dict[str, torch.Tensor], Union[float, LoraConfig], str]],
         fused: bool = True,
         save_original_weight: bool = False,
         lora_converter: Optional[LoRAStateDictConverter] = None,
@@ -84,29 +84,30 @@ class BasePipeline:
         if not lora_converter:
             lora_converter = self.lora_converter
 
-        for
+        for state_dict, lora_item, lora_name in lora_state_dict_list:
             if isinstance(lora_item, float):
                 lora_scale = lora_item
                 scheduler_config = None
-
+            elif isinstance(lora_item, LoraConfig):
                 lora_scale = lora_item.scale
                 scheduler_config = lora_item.scheduler_config
+            else:
+                raise ValueError(f"lora_item must be float or LoraConfig, got {type(lora_item)}")
 
-            logger.info(f"loading lora from {
-            state_dict = load_file(lora_path, device=self.device)
+            logger.info(f"loading lora from state_dict '{lora_name}' with scale={lora_scale}")
 
             if scheduler_config is not None:
                 self.apply_scheduler_config(scheduler_config)
                 logger.info(f"Applied scheduler args from LoraConfig: {scheduler_config}")
 
             lora_state_dict = lora_converter.convert(state_dict)
-            for model_name,
+            for model_name, model_state_dict in lora_state_dict.items():
                 model = getattr(self, model_name)
                 lora_args = []
-                for key, param in
+                for key, param in model_state_dict.items():
                     lora_args.append(
                         {
-                            "name":
+                            "name": lora_name,
                             "key": key,
                             "scale": lora_scale,
                             "rank": param["rank"],
@@ -120,6 +121,26 @@ class BasePipeline:
                     )
                model.load_loras(lora_args, fused=fused)
 
+    def load_loras(
+        self,
+        lora_list: List[Tuple[str, Union[float, LoraConfig]]],
+        fused: bool = True,
+        save_original_weight: bool = False,
+        lora_converter: Optional[LoRAStateDictConverter] = None,
+    ):
+        lora_state_dict_list = []
+        for lora_path, lora_item in lora_list:
+            logger.info(f"loading lora from {lora_path}")
+            state_dict = load_file(lora_path, device=self.device)
+            lora_state_dict_list.append((state_dict, lora_item, lora_path))
+
+        self._load_lora_state_dicts(
+            lora_state_dict_list=lora_state_dict_list,
+            fused=fused,
+            save_original_weight=save_original_weight,
+            lora_converter=lora_converter,
+        )
+
     def load_lora(self, path: str, scale: float, fused: bool = True, save_original_weight: bool = False):
         self.load_loras([(path, scale)], fused, save_original_weight)
 
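Note: file reading now lives in `load_loras`, which loads each safetensors file and delegates to the new `_load_lora_state_dicts`; each entry may carry either a plain float scale or a `LoraConfig`. The following is a hedged usage sketch based only on the signatures shown above, assuming `BasePipeline` is importable from diffsynth_engine.pipelines.base; the file paths are placeholders.

    from diffsynth_engine.pipelines.base import BasePipeline

    def apply_loras(pipe: BasePipeline) -> None:
        # Single LoRA with a float scale; forwarded internally to load_loras().
        pipe.load_lora("/path/to/style_lora.safetensors", scale=0.8)

        # Several LoRAs in one call; each entry is (path, float_scale_or_LoraConfig).
        pipe.load_loras(
            [
                ("/path/to/style_lora.safetensors", 0.8),
                ("/path/to/motion_lora.safetensors", 1.0),
            ],
            fused=True,
        )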
diffsynth_engine/pipelines/wan_video.py

@@ -650,7 +650,7 @@ class WanVideoPipeline(BasePipeline):
             dit_type = "wan2.2-i2v-a14b"
         elif model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 16:
             dit_type = "wan2.2-t2v-a14b"
-        elif model_state_dict["patch_embedding.weight"].shape[1] == 48:
+        elif has_any_key("patch_embedding.weight") and model_state_dict["patch_embedding.weight"].shape[1] == 48:
             dit_type = "wan2.2-ti2v-5b"
         elif has_any_key("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed"):
             dit_type = "wan2.1-flf2v-14b"
@@ -680,6 +680,30 @@ class WanVideoPipeline(BasePipeline):
         if config.attn_params is None:
             config.attn_params = VideoSparseAttentionParams(sparsity=0.9)
 
+    def update_weights(self, state_dicts: WanStateDicts) -> None:
+        is_dual_model_state_dict = (isinstance(state_dicts.model, dict) and
+                                    ("high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model))
+        is_dual_model_pipeline = self.dit2 is not None
+
+        if is_dual_model_state_dict != is_dual_model_pipeline:
+            raise ValueError(
+                f"Model structure mismatch: pipeline has {'dual' if is_dual_model_pipeline else 'single'} model "
+                f"but state_dict is for {'dual' if is_dual_model_state_dict else 'single'} model. "
+                f"Cannot update weights between WAN 2.1 (single model) and WAN 2.2 (dual model)."
+            )
+
+        if is_dual_model_state_dict:
+            if "high_noise_model" in state_dicts.model:
+                self.update_component(self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype)
+            if "low_noise_model" in state_dicts.model:
+                self.update_component(self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype)
+        else:
+            self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
+
+        self.update_component(self.text_encoder, state_dicts.t5, self.config.device, self.config.t5_dtype)
+        self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
+        self.update_component(self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype)
+
     def compile(self):
         self.dit.compile_repeated_blocks()
         if self.dit2 is not None:
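Note: the new `update_weights` refuses to mix WAN 2.1 (single DiT) and WAN 2.2 (dual DiT) checkpoints before swapping component weights in place. A standalone sketch of just that structure check follows; the function name and arguments are hypothetical and simply mirror `state_dicts.model` and the presence of a second DiT.

    def check_model_structure(model_state_dict, pipeline_has_dit2: bool) -> None:
        # Dual-model (WAN 2.2) checkpoints nest per-expert state dicts under these keys.
        is_dual_sd = isinstance(model_state_dict, dict) and (
            "high_noise_model" in model_state_dict or "low_noise_model" in model_state_dict
        )
        if is_dual_sd != pipeline_has_dit2:
            raise ValueError("state_dict structure does not match the pipeline (single vs dual model)")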
diffsynth_engine/utils/parallel.py

@@ -21,117 +21,33 @@ from queue import Empty
 import diffsynth_engine.models.basic.attention as attention_ops
 from diffsynth_engine.utils.platform import empty_cache
 from diffsynth_engine.utils import logging
+from diffsynth_engine.utils.process_group import (
+    PROCESS_GROUP,
+    get_cfg_group,
+    get_cfg_world_size,
+    get_cfg_rank,
+    get_cfg_ranks,
+    get_sp_group,
+    get_sp_world_size,
+    get_sp_rank,
+    get_sp_ranks,
+    get_sp_ulysses_group,
+    get_sp_ulysses_world_size,
+    get_sp_ulysses_rank,
+    get_sp_ulysses_ranks,
+    get_sp_ring_group,
+    get_sp_ring_world_size,
+    get_sp_ring_rank,
+    get_sp_ring_ranks,
+    get_tp_group,
+    get_tp_world_size,
+    get_tp_rank,
+    get_tp_ranks,
+)
 
 logger = logging.get_logger(__name__)
 
 
-class Singleton:
-    _instance = None
-
-    def __new__(cls, *args, **kwargs):
-        if not cls._instance:
-            cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
-        return cls._instance
-
-
-class ProcessGroupSingleton(Singleton):
-    def __init__(self):
-        self.CFG_GROUP: Optional[dist.ProcessGroup] = None
-        self.SP_GROUP: Optional[dist.ProcessGroup] = None
-        self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
-        self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
-        self.TP_GROUP: Optional[dist.ProcessGroup] = None
-
-        self.CFG_RANKS: List[int] = []
-        self.SP_RANKS: List[int] = []
-        self.SP_ULYSSUES_RANKS: List[int] = []
-        self.SP_RING_RANKS: List[int] = []
-        self.TP_RANKS: List[int] = []
-
-
-PROCESS_GROUP = ProcessGroupSingleton()
-
-
-def get_cfg_group():
-    return PROCESS_GROUP.CFG_GROUP
-
-
-def get_cfg_world_size():
-    return PROCESS_GROUP.CFG_GROUP.size() if PROCESS_GROUP.CFG_GROUP is not None else 1
-
-
-def get_cfg_rank():
-    return PROCESS_GROUP.CFG_GROUP.rank() if PROCESS_GROUP.CFG_GROUP is not None else 0
-
-
-def get_cfg_ranks():
-    return PROCESS_GROUP.CFG_RANKS
-
-
-def get_sp_group():
-    return PROCESS_GROUP.SP_GROUP
-
-
-def get_sp_world_size():
-    return PROCESS_GROUP.SP_GROUP.size() if PROCESS_GROUP.SP_GROUP is not None else 1
-
-
-def get_sp_rank():
-    return PROCESS_GROUP.SP_GROUP.rank() if PROCESS_GROUP.SP_GROUP is not None else 0
-
-
-def get_sp_ranks():
-    return PROCESS_GROUP.SP_RANKS
-
-
-def get_sp_ulysses_group():
-    return PROCESS_GROUP.SP_ULYSSUES_GROUP
-
-
-def get_sp_ulysses_world_size():
-    return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
-
-
-def get_sp_ulysses_rank():
-    return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
-
-
-def get_sp_ulysses_ranks():
-    return PROCESS_GROUP.SP_ULYSSUES_RANKS
-
-
-def get_sp_ring_group():
-    return PROCESS_GROUP.SP_RING_GROUP
-
-
-def get_sp_ring_world_size():
-    return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
-
-
-def get_sp_ring_rank():
-    return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
-
-
-def get_sp_ring_ranks():
-    return PROCESS_GROUP.SP_RING_RANKS
-
-
-def get_tp_group():
-    return PROCESS_GROUP.TP_GROUP
-
-
-def get_tp_world_size():
-    return PROCESS_GROUP.TP_GROUP.size() if PROCESS_GROUP.TP_GROUP is not None else 1
-
-
-def get_tp_rank():
-    return PROCESS_GROUP.TP_GROUP.rank() if PROCESS_GROUP.TP_GROUP is not None else 0
-
-
-def get_tp_ranks():
-    return PROCESS_GROUP.TP_RANKS
-
-
 def init_parallel_pgs(
     cfg_degree: int = 1,
     sp_ulysses_degree: int = 1,
diffsynth_engine/utils/process_group.py

@@ -0,0 +1,149 @@
+"""
+Process group management for distributed training.
+
+This module provides singleton-based process group management for distributed training,
+including support for CFG parallelism, sequence parallelism (Ulysses + Ring), and tensor parallelism.
+"""
+
+import torch.distributed as dist
+from typing import Optional, List
+
+
+class Singleton:
+    _instance = None
+
+    def __new__(cls, *args, **kwargs):
+        if not cls._instance:
+            cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
+        return cls._instance
+
+
+class ProcessGroupSingleton(Singleton):
+    def __init__(self):
+        if not hasattr(self, 'initialized'):
+            self.CFG_GROUP: Optional[dist.ProcessGroup] = None
+            self.SP_GROUP: Optional[dist.ProcessGroup] = None
+            self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
+            self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
+            self.TP_GROUP: Optional[dist.ProcessGroup] = None
+
+            self.CFG_RANKS: List[int] = []
+            self.SP_RANKS: List[int] = []
+            self.SP_ULYSSUES_RANKS: List[int] = []
+            self.SP_RING_RANKS: List[int] = []
+            self.TP_RANKS: List[int] = []
+
+            self.initialized = True
+
+
+PROCESS_GROUP = ProcessGroupSingleton()
+
+
+# CFG parallel group functions
+def get_cfg_group():
+    return PROCESS_GROUP.CFG_GROUP
+
+
+def get_cfg_world_size():
+    return PROCESS_GROUP.CFG_GROUP.size() if PROCESS_GROUP.CFG_GROUP is not None else 1
+
+
+def get_cfg_rank():
+    return PROCESS_GROUP.CFG_GROUP.rank() if PROCESS_GROUP.CFG_GROUP is not None else 0
+
+
+def get_cfg_ranks():
+    return PROCESS_GROUP.CFG_RANKS
+
+
+# Sequence parallel group functions
+def get_sp_group():
+    return PROCESS_GROUP.SP_GROUP
+
+
+def get_sp_world_size():
+    return PROCESS_GROUP.SP_GROUP.size() if PROCESS_GROUP.SP_GROUP is not None else 1
+
+
+def get_sp_rank():
+    return PROCESS_GROUP.SP_GROUP.rank() if PROCESS_GROUP.SP_GROUP is not None else 0
+
+
+def get_sp_ranks():
+    return PROCESS_GROUP.SP_RANKS
+
+
+# Sequence parallel Ulysses group functions
+def get_sp_ulysses_group():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP
+
+
+def get_sp_ulysses_world_size():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
+
+
+def get_sp_ulysses_rank():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
+
+
+def get_sp_ulysses_ranks():
+    return PROCESS_GROUP.SP_ULYSSUES_RANKS
+
+
+# Sequence parallel Ring group functions
+def get_sp_ring_group():
+    return PROCESS_GROUP.SP_RING_GROUP
+
+
+def get_sp_ring_world_size():
+    return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
+
+
+def get_sp_ring_rank():
+    return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
+
+
+def get_sp_ring_ranks():
+    return PROCESS_GROUP.SP_RING_RANKS
+
+
+# Tensor parallel group functions
+def get_tp_group():
+    return PROCESS_GROUP.TP_GROUP
+
+
+def get_tp_world_size():
+    return PROCESS_GROUP.TP_GROUP.size() if PROCESS_GROUP.TP_GROUP is not None else 1
+
+
+def get_tp_rank():
+    return PROCESS_GROUP.TP_GROUP.rank() if PROCESS_GROUP.TP_GROUP is not None else 0
+
+
+def get_tp_ranks():
+    return PROCESS_GROUP.TP_RANKS
+
+
+__all__ = [
+    "PROCESS_GROUP",
+    "get_cfg_group",
+    "get_cfg_world_size",
+    "get_cfg_rank",
+    "get_cfg_ranks",
+    "get_sp_group",
+    "get_sp_world_size",
+    "get_sp_rank",
+    "get_sp_ranks",
+    "get_sp_ulysses_group",
+    "get_sp_ulysses_world_size",
+    "get_sp_ulysses_rank",
+    "get_sp_ulysses_ranks",
+    "get_sp_ring_group",
+    "get_sp_ring_world_size",
+    "get_sp_ring_rank",
+    "get_sp_ring_ranks",
+    "get_tp_group",
+    "get_tp_world_size",
+    "get_tp_rank",
+    "get_tp_ranks",
+]
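Note: every getter in the new module falls back to a world size of 1 and rank 0 when its process group has not been created, so single-process code can call them unconditionally. A small sketch of relying on that fallback; `shard_range` is a hypothetical helper, not part of the package.

    from diffsynth_engine.utils.process_group import get_sp_world_size, get_sp_rank

    def shard_range(total_len: int):
        # Split a sequence evenly across sequence-parallel ranks; with no process
        # groups initialised this degenerates to (0, total_len).
        world_size = get_sp_world_size()   # 1 when SP_GROUP is None
        rank = get_sp_rank()               # 0 when SP_GROUP is None
        per_rank = total_len // world_size
        return rank * per_rank, (rank + 1) * per_rank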
{diffsynth_engine-0.6.1.dev28.dist-info → diffsynth_engine-0.6.1.dev30.dist-info}/RECORD

@@ -93,7 +93,7 @@ diffsynth_engine/models/basic/relative_position_emb.py,sha256=rCXOweZMcayVnNUVvB
 diffsynth_engine/models/basic/timestep.py,sha256=WJODYqkSXEM0wcS42YkkfrGwxWt0e60zMTkDdUBQqBw,2810
 diffsynth_engine/models/basic/transformer_helper.py,sha256=6K7A5bVnN2bOoq6I0IQf7RJBhSZUP4jNf1n7NPGu8zA,5287
 diffsynth_engine/models/basic/unet_helper.py,sha256=4lN6F80Ubm6ip4dkLVmB-Og5-Y25Wduhs9Q8qjyzK6E,9044
-diffsynth_engine/models/basic/video_sparse_attention.py,sha256=
+diffsynth_engine/models/basic/video_sparse_attention.py,sha256=GxDN6PTpA1rCoQaXUwSPgH4708bEzVI1qsD48WVDXLA,8201
 diffsynth_engine/models/flux/__init__.py,sha256=x0JoxL0CdiiVrY0BjkIrGinud7mcXecLleGO0km91XQ,686
 diffsynth_engine/models/flux/flux_controlnet.py,sha256=NvFKQIx0NldX5uUxdmYwuS2s-xaFRlKotiE6lr3-HRY,8018
 diffsynth_engine/models/flux/flux_dit.py,sha256=7sdV8KFQiHcK-8aqyvXBgC7E_-D9rcgBcnMXUq_AybI,23403
@@ -143,7 +143,7 @@ diffsynth_engine/models/wan/wan_s2v_dit.py,sha256=j63ulcWLY4XGITOKUMGX292LtSEtP-
 diffsynth_engine/models/wan/wan_text_encoder.py,sha256=OERlmwOqthAFPNnnT2sXJ4OjyyRmsRLx7VGp1zlBkLU,11021
 diffsynth_engine/models/wan/wan_vae.py,sha256=dC7MoUFeXRL7SIY0LG1OOUiZW-pp9IbXCghutMxpXr4,38889
 diffsynth_engine/pipelines/__init__.py,sha256=jh-4LSJ0vqlXiT8BgFgRIQxuAr2atEPyHrxXWj-Ud1U,604
-diffsynth_engine/pipelines/base.py,sha256=
+diffsynth_engine/pipelines/base.py,sha256=ShRiX5MY6bUkRKfuGrA1aalAqeHyeZxhzT87Mwc30b4,17231
 diffsynth_engine/pipelines/flux_image.py,sha256=L0ggxpthLD8a5-zdPHu9z668uWBei9YzPb4PFVypDNU,50707
 diffsynth_engine/pipelines/hunyuan3d_shape.py,sha256=TNV0Wr09Dj2bzzlpua9WioCClOj3YiLfE6utI9aWL8A,8164
 diffsynth_engine/pipelines/qwen_image.py,sha256=ktOirdU2ljgb6vHhXosC0tWgXI3gwvsoAtrYKYvMwzI,35719
@@ -151,7 +151,7 @@ diffsynth_engine/pipelines/sd_image.py,sha256=nr-Nhsnomq8CsUqhTM3i2l2zG01YjwXdfR
 diffsynth_engine/pipelines/sdxl_image.py,sha256=v7ZACGPb6EcBunL6e5E9jynSQjE7GQx8etEV-ZLP91g,21704
 diffsynth_engine/pipelines/utils.py,sha256=HZbJHErNJS1DhlwJKvZ9dY7Kh8Zdlsw3zE2e88TYGRY,2277
 diffsynth_engine/pipelines/wan_s2v.py,sha256=QHlCLMqlmnp55iYm2mzg4qCq4jceRAP3Zt5Mubz3mAM,29384
-diffsynth_engine/pipelines/wan_video.py,sha256=
+diffsynth_engine/pipelines/wan_video.py,sha256=9xjSvQ4mlVEDdaL6QuUURj4iyxhJ2xABBphQjkfzK8s,31323
 diffsynth_engine/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/processor/canny_processor.py,sha256=hV30NlblTkEFUAmF_O-LJrNlGVM2SFrqq6okfF8VpOo,602
 diffsynth_engine/processor/depth_processor.py,sha256=dQvs3JsnyMbz4dyI9QoR8oO-mMFBFAgNvgqeCoaU5jk,1532
@@ -182,15 +182,16 @@ diffsynth_engine/utils/lock.py,sha256=1Ipgst9eEFfFdViAvD5bxdB6HnHHBcqWYOb__fGaPU
 diffsynth_engine/utils/logging.py,sha256=XB0xTT8PBN6btkOjFtOvjlrOCRVgDGT8PFAp1vmse28,467
 diffsynth_engine/utils/offload.py,sha256=94og79TIkxldwYUgZT3L4OVu1WBlE7gfVPvO2MRhm6c,3551
 diffsynth_engine/utils/onnx.py,sha256=jeWUudJHnESjuiEAHyUZYUZz7dCj34O9aGjHCe8yjWo,1149
-diffsynth_engine/utils/parallel.py,sha256=
+diffsynth_engine/utils/parallel.py,sha256=OBGsAK-3ncArRyMU1lea7tbYgxSdCucQvXheL3Ssl5M,17653
 diffsynth_engine/utils/platform.py,sha256=nbpG-XHJFRmYY6u_e7IBQ9Q6GyItrIkKf3VKuBPTUpY,627
+diffsynth_engine/utils/process_group.py,sha256=P-X04a--Zb4M4kjc3DddmusrxCKqv8wiDGhXG4Al-rE,3783
 diffsynth_engine/utils/prompt.py,sha256=YItMchoVzsG6y-LB4vzzDUWrkhKRVlt1HfVhxZjSxMQ,280
 diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CDhg,2200
 diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
 diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
+diffsynth_engine-0.6.1.dev30.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
+diffsynth_engine-0.6.1.dev30.dist-info/METADATA,sha256=z-j4fdSyJwgilKYRl-MrSlhicE8MJP9uvoGYYTFrYKk,1164
+diffsynth_engine-0.6.1.dev30.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+diffsynth_engine-0.6.1.dev30.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
+diffsynth_engine-0.6.1.dev30.dist-info/RECORD,,

{diffsynth_engine-0.6.1.dev28.dist-info → diffsynth_engine-0.6.1.dev30.dist-info}/WHEEL
RENAMED
File without changes

{diffsynth_engine-0.6.1.dev28.dist-info → diffsynth_engine-0.6.1.dev30.dist-info}/licenses/LICENSE
RENAMED
File without changes

{diffsynth_engine-0.6.1.dev28.dist-info → diffsynth_engine-0.6.1.dev30.dist-info}/top_level.txt
RENAMED
File without changes