diffsynth-engine 0.6.1.dev28__py3-none-any.whl → 0.6.1.dev29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,10 +3,15 @@ import math
3
3
  import functools
4
4
 
5
5
  from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE
6
- from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size
6
+ from diffsynth_engine.utils.process_group import get_sp_ulysses_group, get_sp_ring_world_size
7
7
 
8
+
9
+ vsa_core = None
8
10
  if VIDEO_SPARSE_ATTN_AVAILABLE:
9
- from vsa import video_sparse_attn as vsa_core
11
+ try:
12
+ from vsa import video_sparse_attn as vsa_core
13
+ except Exception:
14
+ vsa_core = None
10
15
 
11
16
  VSA_TILE_SIZE = (4, 4, 4)
12
17
 
@@ -171,6 +176,12 @@ def video_sparse_attn(
171
176
  variable_block_sizes: torch.LongTensor,
172
177
  non_pad_index: torch.LongTensor,
173
178
  ):
179
+ if vsa_core is None:
180
+ raise RuntimeError(
181
+ "Video sparse attention (VSA) is not available. "
182
+ "Please install the 'vsa' package and ensure all its dependencies (including pytest) are installed."
183
+ )
184
+
174
185
  q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
175
186
  k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
176
187
  v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
@@ -212,7 +223,8 @@ def distributed_video_sparse_attn(
212
223
  ):
213
224
  from yunchang.comm.all_to_all import SeqAllToAll4D
214
225
 
215
- assert get_sp_ring_world_size() == 1, "distributed video sparse attention requires ring degree to be 1"
226
+ ring_world_size = get_sp_ring_world_size()
227
+ assert ring_world_size == 1, "distributed video sparse attention requires ring degree to be 1"
216
228
  sp_ulysses_group = get_sp_ulysses_group()
217
229
 
218
230
  q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)
@@ -650,7 +650,7 @@ class WanVideoPipeline(BasePipeline):
650
650
  dit_type = "wan2.2-i2v-a14b"
651
651
  elif model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 16:
652
652
  dit_type = "wan2.2-t2v-a14b"
653
- elif model_state_dict["patch_embedding.weight"].shape[1] == 48:
653
+ elif has_any_key("patch_embedding.weight") and model_state_dict["patch_embedding.weight"].shape[1] == 48:
654
654
  dit_type = "wan2.2-ti2v-5b"
655
655
  elif has_any_key("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed"):
656
656
  dit_type = "wan2.1-flf2v-14b"
@@ -680,6 +680,30 @@ class WanVideoPipeline(BasePipeline):
680
680
  if config.attn_params is None:
681
681
  config.attn_params = VideoSparseAttentionParams(sparsity=0.9)
682
682
 
683
+ def update_weights(self, state_dicts: WanStateDicts) -> None:
684
+ is_dual_model_state_dict = (isinstance(state_dicts.model, dict) and
685
+ ("high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model))
686
+ is_dual_model_pipeline = self.dit2 is not None
687
+
688
+ if is_dual_model_state_dict != is_dual_model_pipeline:
689
+ raise ValueError(
690
+ f"Model structure mismatch: pipeline has {'dual' if is_dual_model_pipeline else 'single'} model "
691
+ f"but state_dict is for {'dual' if is_dual_model_state_dict else 'single'} model. "
692
+ f"Cannot update weights between WAN 2.1 (single model) and WAN 2.2 (dual model)."
693
+ )
694
+
695
+ if is_dual_model_state_dict:
696
+ if "high_noise_model" in state_dicts.model:
697
+ self.update_component(self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype)
698
+ if "low_noise_model" in state_dicts.model:
699
+ self.update_component(self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype)
700
+ else:
701
+ self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
702
+
703
+ self.update_component(self.text_encoder, state_dicts.t5, self.config.device, self.config.t5_dtype)
704
+ self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
705
+ self.update_component(self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype)
706
+
683
707
  def compile(self):
684
708
  self.dit.compile_repeated_blocks()
685
709
  if self.dit2 is not None:
@@ -21,117 +21,33 @@ from queue import Empty
21
21
  import diffsynth_engine.models.basic.attention as attention_ops
22
22
  from diffsynth_engine.utils.platform import empty_cache
23
23
  from diffsynth_engine.utils import logging
24
+ from diffsynth_engine.utils.process_group import (
25
+ PROCESS_GROUP,
26
+ get_cfg_group,
27
+ get_cfg_world_size,
28
+ get_cfg_rank,
29
+ get_cfg_ranks,
30
+ get_sp_group,
31
+ get_sp_world_size,
32
+ get_sp_rank,
33
+ get_sp_ranks,
34
+ get_sp_ulysses_group,
35
+ get_sp_ulysses_world_size,
36
+ get_sp_ulysses_rank,
37
+ get_sp_ulysses_ranks,
38
+ get_sp_ring_group,
39
+ get_sp_ring_world_size,
40
+ get_sp_ring_rank,
41
+ get_sp_ring_ranks,
42
+ get_tp_group,
43
+ get_tp_world_size,
44
+ get_tp_rank,
45
+ get_tp_ranks,
46
+ )
24
47
 
25
48
  logger = logging.get_logger(__name__)
26
49
 
27
50
 
28
- class Singleton:
29
- _instance = None
30
-
31
- def __new__(cls, *args, **kwargs):
32
- if not cls._instance:
33
- cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
34
- return cls._instance
35
-
36
-
37
- class ProcessGroupSingleton(Singleton):
38
- def __init__(self):
39
- self.CFG_GROUP: Optional[dist.ProcessGroup] = None
40
- self.SP_GROUP: Optional[dist.ProcessGroup] = None
41
- self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
42
- self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
43
- self.TP_GROUP: Optional[dist.ProcessGroup] = None
44
-
45
- self.CFG_RANKS: List[int] = []
46
- self.SP_RANKS: List[int] = []
47
- self.SP_ULYSSUES_RANKS: List[int] = []
48
- self.SP_RING_RANKS: List[int] = []
49
- self.TP_RANKS: List[int] = []
50
-
51
-
52
- PROCESS_GROUP = ProcessGroupSingleton()
53
-
54
-
55
- def get_cfg_group():
56
- return PROCESS_GROUP.CFG_GROUP
57
-
58
-
59
- def get_cfg_world_size():
60
- return PROCESS_GROUP.CFG_GROUP.size() if PROCESS_GROUP.CFG_GROUP is not None else 1
61
-
62
-
63
- def get_cfg_rank():
64
- return PROCESS_GROUP.CFG_GROUP.rank() if PROCESS_GROUP.CFG_GROUP is not None else 0
65
-
66
-
67
- def get_cfg_ranks():
68
- return PROCESS_GROUP.CFG_RANKS
69
-
70
-
71
- def get_sp_group():
72
- return PROCESS_GROUP.SP_GROUP
73
-
74
-
75
- def get_sp_world_size():
76
- return PROCESS_GROUP.SP_GROUP.size() if PROCESS_GROUP.SP_GROUP is not None else 1
77
-
78
-
79
- def get_sp_rank():
80
- return PROCESS_GROUP.SP_GROUP.rank() if PROCESS_GROUP.SP_GROUP is not None else 0
81
-
82
-
83
- def get_sp_ranks():
84
- return PROCESS_GROUP.SP_RANKS
85
-
86
-
87
- def get_sp_ulysses_group():
88
- return PROCESS_GROUP.SP_ULYSSUES_GROUP
89
-
90
-
91
- def get_sp_ulysses_world_size():
92
- return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
93
-
94
-
95
- def get_sp_ulysses_rank():
96
- return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
97
-
98
-
99
- def get_sp_ulysses_ranks():
100
- return PROCESS_GROUP.SP_ULYSSUES_RANKS
101
-
102
-
103
- def get_sp_ring_group():
104
- return PROCESS_GROUP.SP_RING_GROUP
105
-
106
-
107
- def get_sp_ring_world_size():
108
- return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
109
-
110
-
111
- def get_sp_ring_rank():
112
- return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
113
-
114
-
115
- def get_sp_ring_ranks():
116
- return PROCESS_GROUP.SP_RING_RANKS
117
-
118
-
119
- def get_tp_group():
120
- return PROCESS_GROUP.TP_GROUP
121
-
122
-
123
- def get_tp_world_size():
124
- return PROCESS_GROUP.TP_GROUP.size() if PROCESS_GROUP.TP_GROUP is not None else 1
125
-
126
-
127
- def get_tp_rank():
128
- return PROCESS_GROUP.TP_GROUP.rank() if PROCESS_GROUP.TP_GROUP is not None else 0
129
-
130
-
131
- def get_tp_ranks():
132
- return PROCESS_GROUP.TP_RANKS
133
-
134
-
135
51
  def init_parallel_pgs(
136
52
  cfg_degree: int = 1,
137
53
  sp_ulysses_degree: int = 1,
@@ -0,0 +1,149 @@
1
+ """
2
+ Process group management for distributed training.
3
+
4
+ This module provides singleton-based process group management for distributed training,
5
+ including support for CFG parallelism, sequence parallelism (Ulysses + Ring), and tensor parallelism.
6
+ """
7
+
8
+ import torch.distributed as dist
9
+ from typing import Optional, List
10
+
11
+
12
+ class Singleton:
13
+ _instance = None
14
+
15
+ def __new__(cls, *args, **kwargs):
16
+ if not cls._instance:
17
+ cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
18
+ return cls._instance
19
+
20
+
21
+ class ProcessGroupSingleton(Singleton):
22
+ def __init__(self):
23
+ if not hasattr(self, 'initialized'):
24
+ self.CFG_GROUP: Optional[dist.ProcessGroup] = None
25
+ self.SP_GROUP: Optional[dist.ProcessGroup] = None
26
+ self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
27
+ self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
28
+ self.TP_GROUP: Optional[dist.ProcessGroup] = None
29
+
30
+ self.CFG_RANKS: List[int] = []
31
+ self.SP_RANKS: List[int] = []
32
+ self.SP_ULYSSUES_RANKS: List[int] = []
33
+ self.SP_RING_RANKS: List[int] = []
34
+ self.TP_RANKS: List[int] = []
35
+
36
+ self.initialized = True
37
+
38
+
39
+ PROCESS_GROUP = ProcessGroupSingleton()
40
+
41
+
42
+ # CFG parallel group functions
43
+ def get_cfg_group():
44
+ return PROCESS_GROUP.CFG_GROUP
45
+
46
+
47
+ def get_cfg_world_size():
48
+ return PROCESS_GROUP.CFG_GROUP.size() if PROCESS_GROUP.CFG_GROUP is not None else 1
49
+
50
+
51
+ def get_cfg_rank():
52
+ return PROCESS_GROUP.CFG_GROUP.rank() if PROCESS_GROUP.CFG_GROUP is not None else 0
53
+
54
+
55
+ def get_cfg_ranks():
56
+ return PROCESS_GROUP.CFG_RANKS
57
+
58
+
59
+ # Sequence parallel group functions
60
+ def get_sp_group():
61
+ return PROCESS_GROUP.SP_GROUP
62
+
63
+
64
+ def get_sp_world_size():
65
+ return PROCESS_GROUP.SP_GROUP.size() if PROCESS_GROUP.SP_GROUP is not None else 1
66
+
67
+
68
+ def get_sp_rank():
69
+ return PROCESS_GROUP.SP_GROUP.rank() if PROCESS_GROUP.SP_GROUP is not None else 0
70
+
71
+
72
+ def get_sp_ranks():
73
+ return PROCESS_GROUP.SP_RANKS
74
+
75
+
76
+ # Sequence parallel Ulysses group functions
77
+ def get_sp_ulysses_group():
78
+ return PROCESS_GROUP.SP_ULYSSUES_GROUP
79
+
80
+
81
+ def get_sp_ulysses_world_size():
82
+ return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
83
+
84
+
85
+ def get_sp_ulysses_rank():
86
+ return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
87
+
88
+
89
+ def get_sp_ulysses_ranks():
90
+ return PROCESS_GROUP.SP_ULYSSUES_RANKS
91
+
92
+
93
+ # Sequence parallel Ring group functions
94
+ def get_sp_ring_group():
95
+ return PROCESS_GROUP.SP_RING_GROUP
96
+
97
+
98
+ def get_sp_ring_world_size():
99
+ return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
100
+
101
+
102
+ def get_sp_ring_rank():
103
+ return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
104
+
105
+
106
+ def get_sp_ring_ranks():
107
+ return PROCESS_GROUP.SP_RING_RANKS
108
+
109
+
110
+ # Tensor parallel group functions
111
+ def get_tp_group():
112
+ return PROCESS_GROUP.TP_GROUP
113
+
114
+
115
+ def get_tp_world_size():
116
+ return PROCESS_GROUP.TP_GROUP.size() if PROCESS_GROUP.TP_GROUP is not None else 1
117
+
118
+
119
+ def get_tp_rank():
120
+ return PROCESS_GROUP.TP_GROUP.rank() if PROCESS_GROUP.TP_GROUP is not None else 0
121
+
122
+
123
+ def get_tp_ranks():
124
+ return PROCESS_GROUP.TP_RANKS
125
+
126
+
127
+ __all__ = [
128
+ "PROCESS_GROUP",
129
+ "get_cfg_group",
130
+ "get_cfg_world_size",
131
+ "get_cfg_rank",
132
+ "get_cfg_ranks",
133
+ "get_sp_group",
134
+ "get_sp_world_size",
135
+ "get_sp_rank",
136
+ "get_sp_ranks",
137
+ "get_sp_ulysses_group",
138
+ "get_sp_ulysses_world_size",
139
+ "get_sp_ulysses_rank",
140
+ "get_sp_ulysses_ranks",
141
+ "get_sp_ring_group",
142
+ "get_sp_ring_world_size",
143
+ "get_sp_ring_rank",
144
+ "get_sp_ring_ranks",
145
+ "get_tp_group",
146
+ "get_tp_world_size",
147
+ "get_tp_rank",
148
+ "get_tp_ranks",
149
+ ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.6.1.dev28
3
+ Version: 0.6.1.dev29
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -93,7 +93,7 @@ diffsynth_engine/models/basic/relative_position_emb.py,sha256=rCXOweZMcayVnNUVvB
93
93
  diffsynth_engine/models/basic/timestep.py,sha256=WJODYqkSXEM0wcS42YkkfrGwxWt0e60zMTkDdUBQqBw,2810
94
94
  diffsynth_engine/models/basic/transformer_helper.py,sha256=6K7A5bVnN2bOoq6I0IQf7RJBhSZUP4jNf1n7NPGu8zA,5287
95
95
  diffsynth_engine/models/basic/unet_helper.py,sha256=4lN6F80Ubm6ip4dkLVmB-Og5-Y25Wduhs9Q8qjyzK6E,9044
96
- diffsynth_engine/models/basic/video_sparse_attention.py,sha256=iXA3sHDLWk1ns1lVCNbZdiaDu94kBIsw-9vrCGAll7g,7843
96
+ diffsynth_engine/models/basic/video_sparse_attention.py,sha256=GxDN6PTpA1rCoQaXUwSPgH4708bEzVI1qsD48WVDXLA,8201
97
97
  diffsynth_engine/models/flux/__init__.py,sha256=x0JoxL0CdiiVrY0BjkIrGinud7mcXecLleGO0km91XQ,686
98
98
  diffsynth_engine/models/flux/flux_controlnet.py,sha256=NvFKQIx0NldX5uUxdmYwuS2s-xaFRlKotiE6lr3-HRY,8018
99
99
  diffsynth_engine/models/flux/flux_dit.py,sha256=7sdV8KFQiHcK-8aqyvXBgC7E_-D9rcgBcnMXUq_AybI,23403
@@ -151,7 +151,7 @@ diffsynth_engine/pipelines/sd_image.py,sha256=nr-Nhsnomq8CsUqhTM3i2l2zG01YjwXdfR
151
151
  diffsynth_engine/pipelines/sdxl_image.py,sha256=v7ZACGPb6EcBunL6e5E9jynSQjE7GQx8etEV-ZLP91g,21704
152
152
  diffsynth_engine/pipelines/utils.py,sha256=HZbJHErNJS1DhlwJKvZ9dY7Kh8Zdlsw3zE2e88TYGRY,2277
153
153
  diffsynth_engine/pipelines/wan_s2v.py,sha256=QHlCLMqlmnp55iYm2mzg4qCq4jceRAP3Zt5Mubz3mAM,29384
154
- diffsynth_engine/pipelines/wan_video.py,sha256=rJq60LiaCoLq1PkqUzzrdvFkp6h73fc-ZUu0MiMQC-c,29668
154
+ diffsynth_engine/pipelines/wan_video.py,sha256=9xjSvQ4mlVEDdaL6QuUURj4iyxhJ2xABBphQjkfzK8s,31323
155
155
  diffsynth_engine/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
156
156
  diffsynth_engine/processor/canny_processor.py,sha256=hV30NlblTkEFUAmF_O-LJrNlGVM2SFrqq6okfF8VpOo,602
157
157
  diffsynth_engine/processor/depth_processor.py,sha256=dQvs3JsnyMbz4dyI9QoR8oO-mMFBFAgNvgqeCoaU5jk,1532
@@ -182,15 +182,16 @@ diffsynth_engine/utils/lock.py,sha256=1Ipgst9eEFfFdViAvD5bxdB6HnHHBcqWYOb__fGaPU
182
182
  diffsynth_engine/utils/logging.py,sha256=XB0xTT8PBN6btkOjFtOvjlrOCRVgDGT8PFAp1vmse28,467
183
183
  diffsynth_engine/utils/offload.py,sha256=94og79TIkxldwYUgZT3L4OVu1WBlE7gfVPvO2MRhm6c,3551
184
184
  diffsynth_engine/utils/onnx.py,sha256=jeWUudJHnESjuiEAHyUZYUZz7dCj34O9aGjHCe8yjWo,1149
185
- diffsynth_engine/utils/parallel.py,sha256=6T8oCTp-7Gb3qsgNRB2Bp3DF4eyx1FzvS6pFnEJbsek,19789
185
+ diffsynth_engine/utils/parallel.py,sha256=OBGsAK-3ncArRyMU1lea7tbYgxSdCucQvXheL3Ssl5M,17653
186
186
  diffsynth_engine/utils/platform.py,sha256=nbpG-XHJFRmYY6u_e7IBQ9Q6GyItrIkKf3VKuBPTUpY,627
187
+ diffsynth_engine/utils/process_group.py,sha256=P-X04a--Zb4M4kjc3DddmusrxCKqv8wiDGhXG4Al-rE,3783
187
188
  diffsynth_engine/utils/prompt.py,sha256=YItMchoVzsG6y-LB4vzzDUWrkhKRVlt1HfVhxZjSxMQ,280
188
189
  diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CDhg,2200
189
190
  diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
190
191
  diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
191
192
  diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
192
- diffsynth_engine-0.6.1.dev28.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
193
- diffsynth_engine-0.6.1.dev28.dist-info/METADATA,sha256=2LB9DNq9Pf8-sOj9A7_7EzP88Fh--XGF-0hDSP94DPE,1164
194
- diffsynth_engine-0.6.1.dev28.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
195
- diffsynth_engine-0.6.1.dev28.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
196
- diffsynth_engine-0.6.1.dev28.dist-info/RECORD,,
193
+ diffsynth_engine-0.6.1.dev29.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
194
+ diffsynth_engine-0.6.1.dev29.dist-info/METADATA,sha256=8A5q0qhRMxeJi7IOvP3dcqk58BsgIBxy16ndlnDM_6I,1164
195
+ diffsynth_engine-0.6.1.dev29.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
196
+ diffsynth_engine-0.6.1.dev29.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
197
+ diffsynth_engine-0.6.1.dev29.dist-info/RECORD,,