diffsynth-engine 0.0.0 (diffsynth_engine-0.0.0-py3-none-any.whl)

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
Files changed (127)
  1. diffsynth_engine/__init__.py +28 -0
  2. diffsynth_engine/algorithm/__init__.py +0 -0
  3. diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
  4. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
  5. diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
  6. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
  7. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
  8. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +50 -0
  9. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  10. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
  11. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +25 -0
  12. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
  13. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
  14. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
  15. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +17 -0
  16. diffsynth_engine/algorithm/sampler/__init__.py +19 -0
  17. diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  18. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
  19. diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  20. diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
  21. diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
  22. diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
  23. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
  24. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
  25. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
  26. diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
  27. diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
  28. diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
  29. diffsynth_engine/conf/models/components/vae.json +254 -0
  30. diffsynth_engine/conf/models/flux/flux_dit.json +105 -0
  31. diffsynth_engine/conf/models/flux/flux_text_encoder.json +20 -0
  32. diffsynth_engine/conf/models/flux/flux_vae.json +250 -0
  33. diffsynth_engine/conf/models/sd/sd_text_encoder.json +220 -0
  34. diffsynth_engine/conf/models/sd/sd_unet.json +397 -0
  35. diffsynth_engine/conf/models/sd3/sd3_dit.json +908 -0
  36. diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +756 -0
  37. diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +455 -0
  38. diffsynth_engine/conf/models/sdxl/sdxl_unet.json +1056 -0
  39. diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +13 -0
  40. diffsynth_engine/conf/models/wan/dit/14b-i2v.json +13 -0
  41. diffsynth_engine/conf/models/wan/dit/14b-t2v.json +13 -0
  42. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
  43. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
  44. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
  45. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
  46. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
  47. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  48. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
  49. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
  50. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
  51. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
  52. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
  53. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
  54. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
  55. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
  56. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
  57. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
  58. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
  59. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  60. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
  61. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
  62. diffsynth_engine/kernels/__init__.py +0 -0
  63. diffsynth_engine/models/__init__.py +7 -0
  64. diffsynth_engine/models/base.py +64 -0
  65. diffsynth_engine/models/basic/__init__.py +0 -0
  66. diffsynth_engine/models/basic/attention.py +217 -0
  67. diffsynth_engine/models/basic/lora.py +293 -0
  68. diffsynth_engine/models/basic/relative_position_emb.py +56 -0
  69. diffsynth_engine/models/basic/timestep.py +81 -0
  70. diffsynth_engine/models/basic/transformer_helper.py +88 -0
  71. diffsynth_engine/models/basic/unet_helper.py +244 -0
  72. diffsynth_engine/models/components/__init__.py +0 -0
  73. diffsynth_engine/models/components/clip.py +56 -0
  74. diffsynth_engine/models/components/t5.py +222 -0
  75. diffsynth_engine/models/components/vae.py +392 -0
  76. diffsynth_engine/models/flux/__init__.py +14 -0
  77. diffsynth_engine/models/flux/flux_dit.py +476 -0
  78. diffsynth_engine/models/flux/flux_text_encoder.py +88 -0
  79. diffsynth_engine/models/flux/flux_vae.py +78 -0
  80. diffsynth_engine/models/sd/__init__.py +12 -0
  81. diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
  82. diffsynth_engine/models/sd/sd_unet.py +293 -0
  83. diffsynth_engine/models/sd/sd_vae.py +38 -0
  84. diffsynth_engine/models/sd3/__init__.py +14 -0
  85. diffsynth_engine/models/sd3/sd3_dit.py +302 -0
  86. diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
  87. diffsynth_engine/models/sd3/sd3_vae.py +43 -0
  88. diffsynth_engine/models/sdxl/__init__.py +13 -0
  89. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
  90. diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
  91. diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
  92. diffsynth_engine/models/utils.py +54 -0
  93. diffsynth_engine/models/wan/__init__.py +0 -0
  94. diffsynth_engine/models/wan/wan_dit.py +497 -0
  95. diffsynth_engine/models/wan/wan_image_encoder.py +494 -0
  96. diffsynth_engine/models/wan/wan_text_encoder.py +297 -0
  97. diffsynth_engine/models/wan/wan_vae.py +771 -0
  98. diffsynth_engine/pipelines/__init__.py +18 -0
  99. diffsynth_engine/pipelines/base.py +253 -0
  100. diffsynth_engine/pipelines/flux_image.py +512 -0
  101. diffsynth_engine/pipelines/sd_image.py +352 -0
  102. diffsynth_engine/pipelines/sdxl_image.py +395 -0
  103. diffsynth_engine/pipelines/wan_video.py +524 -0
  104. diffsynth_engine/tokenizers/__init__.py +6 -0
  105. diffsynth_engine/tokenizers/base.py +157 -0
  106. diffsynth_engine/tokenizers/clip.py +288 -0
  107. diffsynth_engine/tokenizers/t5.py +194 -0
  108. diffsynth_engine/tokenizers/wan.py +74 -0
  109. diffsynth_engine/utils/__init__.py +0 -0
  110. diffsynth_engine/utils/constants.py +34 -0
  111. diffsynth_engine/utils/download.py +135 -0
  112. diffsynth_engine/utils/env.py +7 -0
  113. diffsynth_engine/utils/flag.py +46 -0
  114. diffsynth_engine/utils/fp8_linear.py +64 -0
  115. diffsynth_engine/utils/gguf.py +415 -0
  116. diffsynth_engine/utils/loader.py +17 -0
  117. diffsynth_engine/utils/lock.py +56 -0
  118. diffsynth_engine/utils/logging.py +12 -0
  119. diffsynth_engine/utils/offload.py +44 -0
  120. diffsynth_engine/utils/parallel.py +390 -0
  121. diffsynth_engine/utils/prompt.py +9 -0
  122. diffsynth_engine/utils/video.py +40 -0
  123. diffsynth_engine-0.0.0.dist-info/LICENSE +201 -0
  124. diffsynth_engine-0.0.0.dist-info/METADATA +236 -0
  125. diffsynth_engine-0.0.0.dist-info/RECORD +127 -0
  126. diffsynth_engine-0.0.0.dist-info/WHEEL +5 -0
  127. diffsynth_engine-0.0.0.dist-info/top_level.txt +1 -0
diffsynth_engine/utils/lock.py
@@ -0,0 +1,56 @@
+ import math
+ import threading
+ from typing import Optional
+ from types import TracebackType
+ from flufl.lock import Lock
+
+ from diffsynth_engine.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ class HeartbeatFileLock:
+     def __init__(self, lock_file_path: str, heartbeat_interval: float = 10):
+         self.lock_file_path = lock_file_path
+         self.heartbeat_interval = heartbeat_interval
+         self.lifetime = math.ceil(heartbeat_interval + 1)
+         self.heartbeat_thread = None
+         self.stop_event = threading.Event()
+         self.lock = None
+
+     def _heartbeat(self):
+         while not self.stop_event.is_set():
+             self.lock.refresh(lifetime=self.lifetime)
+             self.stop_event.wait(self.heartbeat_interval - 1)
+
+     def acquire(self):
+         self.lock = Lock(self.lock_file_path, lifetime=self.lifetime)
+         self.lock.lock()
+
+         self.heartbeat_thread = threading.Thread(target=self._heartbeat)
+         self.heartbeat_thread.start()
+
+     def release(self):
+         self._release()  # stop the heartbeat first so it cannot refresh a released lock
+         if self.lock is not None:
+             self.lock.unlock(unconditionally=True)
+
+     def _release(self):
+         if self.heartbeat_thread is not None:
+             self.stop_event.set()
+             self.heartbeat_thread.join()
+
+     def __enter__(self):
+         self.acquire()
+         return self
+
+     def __exit__(
+         self,
+         exc_type: Optional[type[BaseException]] = None,
+         exc_value: Optional[BaseException] = None,
+         traceback: Optional[TracebackType] = None,
+     ):
+         self.release()
+
+     def __del__(self):
+         self.release()
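
For orientation, a minimal usage sketch of HeartbeatFileLock (the lock path and the guarded work below are hypothetical):

    from diffsynth_engine.utils.lock import HeartbeatFileLock

    # Serialize a cross-process critical section. While the lock is held, the
    # background heartbeat keeps refreshing the flufl.lock lifetime so the lock
    # is never reclaimed as stale from a live holder.
    with HeartbeatFileLock("/tmp/diffsynth_download.lock", heartbeat_interval=10):
        pass  # e.g. download or convert a model checkpoint here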
diffsynth_engine/utils/logging.py
@@ -0,0 +1,12 @@
+ import logging
+ from typing import Optional
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(levelname)s - %(message)s",
+     datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+
+ def get_logger(name: Optional[str] = None) -> logging.Logger:
+     return logging.getLogger(name)
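
This matches how modules elsewhere in the package obtain their loggers:

    from diffsynth_engine.utils import logging

    logger = logging.get_logger(__name__)
    logger.info("pipeline initialized")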
diffsynth_engine/utils/offload.py
@@ -0,0 +1,44 @@
+ import torch.nn as nn
+
+ from diffsynth_engine.models.basic.transformer_helper import RMSNorm
+ from diffsynth_engine.models.basic.relative_position_emb import RelativePositionEmbedding
+
+
+ SUPPORTED_OFFLOAD_MODULES = (
+     nn.Embedding,
+     nn.Linear,
+     nn.LayerNorm,
+     nn.Conv2d,
+     nn.GroupNorm,
+     RMSNorm,
+     RelativePositionEmbedding,
+ )
+
+
+ def enable_sequential_cpu_offload(module: nn.Module, device: str = "cuda:0"):
+     if isinstance(module, SUPPORTED_OFFLOAD_MODULES):
+         add_cpu_offload_hook(module, device)
+         return
+     for submodule in module.children():
+         enable_sequential_cpu_offload(submodule, device)
+
+
+ def add_cpu_offload_hook(module: nn.Module, device: str = "cuda:0"):
+     def _forward_pre_hook(module: nn.Module, input):
+         # stash the CPU tensors, then move working copies to the target device
+         offload_params = {}
+         for name, param in module.named_parameters():
+             offload_params[name] = param.data
+             param.data = param.data.to(device=device)
+         setattr(module, "_offload_params", offload_params)
+
+     def _forward_hook(module: nn.Module, input, output):
+         # restore the stashed CPU tensors, releasing the device copies
+         offload_params = getattr(module, "_offload_params", {})
+         for name, param in module.named_parameters():
+             if name in offload_params:
+                 param.data = offload_params[name]
+
+     if getattr(module, "_sequential_cpu_offload_enabled", False):
+         return
+     module.register_forward_pre_hook(_forward_pre_hook)
+     module.register_forward_hook(_forward_hook)
+     setattr(module, "_sequential_cpu_offload_enabled", True)
diffsynth_engine/utils/parallel.py
@@ -0,0 +1,390 @@
+ import os
+ import copy
+ import torch
+ import torch.nn as nn
+ import torch.multiprocessing as mp
+ import torch.distributed as dist
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ from torch.distributed.fsdp import ShardingStrategy
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
+ from torch.distributed.device_mesh import DeviceMesh
+ from torch.distributed.tensor.parallel.style import ParallelStyle
+ from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
+ from datetime import timedelta
+ from functools import partial
+ from typing import Any, Callable, Dict, List, Union, Optional
+ from queue import Empty
+ from yunchang.globals import Singleton, set_seq_parallel_pg
+
+ from diffsynth_engine.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ class ProcessGroupSingleton(Singleton):
+     def __init__(self):
+         self.CFG_GROUP: dist.ProcessGroup = None
+         self.SP_GROUP: dist.ProcessGroup = None
+         self.TP_GROUP: dist.ProcessGroup = None
+
+         self.CFG_RANKS: List[int] = []
+         self.SP_RANKS: List[int] = []
+         self.TP_RANKS: List[int] = []
+
+
+ PROCESS_GROUP = ProcessGroupSingleton()
+
+
+ def get_cfg_group():
+     return PROCESS_GROUP.CFG_GROUP
+
+
+ def get_cfg_world_size():
+     return PROCESS_GROUP.CFG_GROUP.size()
+
+
+ def get_cfg_rank():
+     return PROCESS_GROUP.CFG_GROUP.rank()
+
+
+ def get_cfg_ranks():
+     return PROCESS_GROUP.CFG_RANKS
+
+
+ def get_sp_group():
+     return PROCESS_GROUP.SP_GROUP
+
+
+ def get_sp_world_size():
+     return PROCESS_GROUP.SP_GROUP.size()
+
+
+ def get_sp_rank():
+     return PROCESS_GROUP.SP_GROUP.rank()
+
+
+ def get_sp_ranks():
+     return PROCESS_GROUP.SP_RANKS
+
+
+ def get_tp_group():
+     return PROCESS_GROUP.TP_GROUP
+
+
+ def get_tp_world_size():
+     return PROCESS_GROUP.TP_GROUP.size()
+
+
+ def get_tp_rank():
+     return PROCESS_GROUP.TP_GROUP.rank()
+
+
+ def get_tp_ranks():
+     return PROCESS_GROUP.TP_RANKS
+
+
+ def init_parallel_pgs(
+     cfg_degree: int = 1,
+     sp_ulysses_degree: int = 1,
+     sp_ring_degree: int = 1,
+     tp_degree: int = 1,
+     rank: int = 0,
+     world_size: int = 1,
+ ):
+     sp_degree = sp_ulysses_degree * sp_ring_degree
+
+     assert sp_degree == 1 or tp_degree == 1, "not allowed to enable sequence parallel and tensor parallel together"
+     assert world_size == cfg_degree * sp_degree * tp_degree, (
+         f"world_size ({world_size}) must be equal to cfg_degree ({cfg_degree}) * sp_degree ({sp_degree}) * tp_degree ({tp_degree})"
+     )
+
+     def make_parallel_groups(blocks: List[List[int]], degree: int):
+         groups, chunks = [], []
+         for block in blocks:
+             size = len(block) // degree
+             chunk = [block[i * size : (i + 1) * size] for i in range(degree)]
+             chunks.extend(chunk)
+             groups.extend(list(zip(*chunk)))
+         return groups, chunks
+
+     blocks = [list(range(world_size))]
+     cfg_groups, cfg_blocks = make_parallel_groups(blocks, cfg_degree)
+     for cfg_ranks in cfg_groups:
+         cfg_group = dist.new_group(cfg_ranks)
+         if rank in cfg_ranks:
+             PROCESS_GROUP.CFG_GROUP = cfg_group
+             PROCESS_GROUP.CFG_RANKS = cfg_ranks
+
+     sp_groups, sp_blocks = make_parallel_groups(cfg_blocks, sp_degree)
+     for sp_ranks in sp_groups:
+         group = dist.new_group(sp_ranks)
+         if rank in sp_ranks:
+             PROCESS_GROUP.SP_GROUP = group
+             PROCESS_GROUP.SP_RANKS = sp_ranks
+
+     tp_groups, _ = make_parallel_groups(sp_blocks, tp_degree)
+     for tp_ranks in tp_groups:
+         group = dist.new_group(tp_ranks)
+         if rank in tp_ranks:
+             PROCESS_GROUP.TP_GROUP = group
+             PROCESS_GROUP.TP_RANKS = tp_ranks
+
+     set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
+
+
+ def clone(data):
+     if isinstance(data, dict):
+         return {k: clone(v) for k, v in data.items()}
+     if isinstance(data, (tuple, list)):
+         return [clone(t) for t in data]
+     elif isinstance(data, torch.Tensor):
+         return data.clone()
+     else:
+         return copy.deepcopy(data)
+
+
+ def to_device(data, device):
+     if isinstance(data, dict):
+         return {k: to_device(v, device) for k, v in data.items()}
+     if isinstance(data, (tuple, list)):
+         return [to_device(t, device) for t in data]
+     elif isinstance(data, torch.Tensor):
+         return data.to(device)
+     else:
+         return data
+
+
+ def split_and_get(data, num, dim, index):
+     if isinstance(data, dict):
+         return {k: split_and_get(v, num, dim, index) for k, v in data.items()}
+     if isinstance(data, (tuple, list)):
+         return [split_and_get(t, num, dim, index) for t in data]
+     if isinstance(data, torch.Tensor):
+         if data.shape[dim] < num:
+             raise ValueError(f"data.shape[{dim}] ({data.shape[dim]}) < num ({num}), split failed")
+         return torch.split(data, data.shape[dim] // num, dim)[index]
+     return data
+
+
+ def shard_model(
+     module: nn.Module,
+     device_id: int,
+     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD,
+     wrap_module_names: Optional[List[str]] = None,
+ ):
+     wrap_module_names = wrap_module_names or []
+
+     def wrap_fn(m):
+         for name in wrap_module_names:
+             submodule = getattr(module, name)
+             if isinstance(submodule, nn.ModuleList) and m in submodule:
+                 return True
+             elif not isinstance(submodule, nn.ModuleList) and m is submodule:
+                 return True
+         return False
+
+     return FSDP(
+         module,
+         device_id=device_id,
+         sharding_strategy=sharding_strategy,
+         auto_wrap_policy=partial(lambda_auto_wrap_policy, lambda_fn=wrap_fn),
+     )
+
+
+ def parallelize_module(
+     module: nn.Module,
+     device_mesh: DeviceMesh,
+     parallelize_plan: Optional[Union[ParallelStyle, Dict[str, ParallelStyle]]] = None,
+ ):
+     _validate_tp_mesh_dim(device_mesh)
+     if parallelize_plan is None:
+         return module
+     if isinstance(parallelize_plan, ParallelStyle):
+         return parallelize_plan._apply(module, device_mesh)
+     for module_path, parallelize_style in parallelize_plan.items():
+         if module_path.strip() == "":
+             raise ValueError("Expect module path to be non-empty, but got empty string!")
+         try:
+             submodule = module.get_submodule(module_path)
+             parallelize_style._apply(submodule, device_mesh)
+         except AttributeError:
+             continue
+     return module
+
+
+ NCCL_TIMEOUT_SEC = int(os.environ.get("NCCL_TIMEOUT_SEC", 600))
+ PARALLEL_FWD_TIMEOUT_SEC = int(os.environ.get("PARALLEL_FWD_TIMEOUT_SEC", 300))
+ PARALLEL_LORA_TIMEOUT_SEC = int(os.environ.get("PARALLEL_LORA_TIMEOUT_SEC", 60))
+
+
+ def _worker_loop(
+     rank: int,
+     world_size: int,
+     queue_in: mp.Queue,
+     queue_out: mp.Queue,
+     module: nn.Module,
+     cfg_degree: int,
+     sp_ulysses_degree: int,
+     sp_ring_degree: int,
+     tp_degree: int,
+     shard_fn: Optional[Callable] = None,
+     master_port: int = 29500,
+     device: str = "cuda",
+ ):
+     """
+     https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
+     """
+     try:
+         os.environ["RANK"] = str(rank)
+         os.environ["WORLD_SIZE"] = str(world_size)
+         os.environ["MASTER_ADDR"] = "localhost"
+         os.environ["MASTER_PORT"] = str(master_port)
+         torch.cuda.set_device(rank)
+
+         timeout = timedelta(seconds=NCCL_TIMEOUT_SEC)
+         dist.init_process_group(
+             backend="nccl",
+             init_method="env://",
+             timeout=timeout,
+             world_size=world_size,
+             rank=rank,
+         )
+         init_parallel_pgs(
+             cfg_degree=cfg_degree,
+             sp_ulysses_degree=sp_ulysses_degree,
+             sp_ring_degree=sp_ring_degree,
+             tp_degree=tp_degree,
+             rank=rank,
+             world_size=world_size,
+         )
+
+         if tp_degree > 1:
+             module = parallelize_module(
+                 module=module,
+                 device_mesh=DeviceMesh(device, torch.tensor(get_tp_ranks())),
+                 parallelize_plan=module.get_tp_plan(),
+             ).to(device)
+         elif shard_fn:
+             module = shard_fn(module=module, device_id=rank)
+         else:
+             module = module.to(device)
+
+         while True:
+             # rank 0 receives a request from the queue and broadcasts it to the other ranks
+             if rank == 0:
+                 kwargs = queue_in.get()
+                 data = [kwargs]
+             else:
+                 data = [None]
+             dist.broadcast_object_list(data, src=0)
+             kwargs = clone(data[0])
+             del data
+
+             y = None
+             if kwargs.get("method", None) == "load_loras":
+                 module.load_loras(lora_args=kwargs["lora_args"], fused=kwargs["fused"])
+             elif kwargs.get("method", None) == "unload_loras":
+                 module.unload_loras()
+             else:
+                 kwargs = to_device(kwargs, device)
+                 kwargs = split_and_get(kwargs, get_cfg_world_size(), 0, get_cfg_rank())
+                 with torch.no_grad():
+                     y = module(**kwargs)
+                 if get_sp_rank() == 0 and get_tp_rank() == 0:
+                     gathered = torch.zeros((get_cfg_world_size(), *y.shape[1:]), dtype=y.dtype, device=y.device)
+                     dist.all_gather_into_tensor(gathered, y, group=get_cfg_group())
+                     y = gathered
+
+             if rank == 0:
+                 queue_out.put(y)
+             dist.barrier()
+     except Exception as e:
+         import traceback
+
+         traceback.print_exc()
+         logger.error(f"Error in worker loop (rank {rank}): {e}")
+     finally:
+         del module
+         torch.cuda.synchronize()
+         torch.cuda.empty_cache()
+         dist.destroy_process_group()
+
+
+ class ParallelModel(nn.Module):
+     def __init__(
+         self,
+         module: nn.Module,
+         cfg_degree: int,
+         sp_ulysses_degree: int,
+         sp_ring_degree: int,
+         tp_degree: int,
+         shard_fn: Optional[Callable] = None,
+         master_port: int = 29500,
+         device: str = "cuda",
+     ):
+         super().__init__()
+         self.world_size = cfg_degree * sp_ulysses_degree * sp_ring_degree * tp_degree
+         self.device = device
+         self.queue_in = mp.Queue()
+         self.queue_out = mp.Queue()
+         self.ctx = mp.spawn(
+             _worker_loop,
+             args=(
+                 self.world_size,
+                 self.queue_in,
+                 self.queue_out,
+                 module,
+                 cfg_degree,
+                 sp_ulysses_degree,
+                 sp_ring_degree,
+                 tp_degree,
+                 shard_fn,
+                 master_port,
+                 device,
+             ),
+             nprocs=self.world_size,
+             join=False,
+         )
+
+     def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = True):
+         self.queue_in.put(
+             {
+                 "method": "load_loras",
+                 "lora_args": lora_args,
+                 "fused": fused,
+             }
+         )
+         try:
+             _ = self.queue_out.get(timeout=PARALLEL_LORA_TIMEOUT_SEC)
+         except Empty:
+             logger.error("Parallel model load LoRA timeout")
+             raise RuntimeError("Parallel model load LoRA timeout")
+         logger.info("Parallel model load LoRA done")
+
+     def unload_loras(self):
+         self.queue_in.put({"method": "unload_loras"})
+         try:
+             _ = self.queue_out.get(timeout=PARALLEL_LORA_TIMEOUT_SEC)
+         except Empty:
+             logger.error("Parallel model unload LoRA timeout")
+             raise RuntimeError("Parallel model unload LoRA timeout")
+         logger.info("Parallel model unload LoRA done")
+
+     def forward(self, **kwargs):
+         self.queue_in.put(kwargs)
+         try:
+             y = self.queue_out.get(timeout=PARALLEL_FWD_TIMEOUT_SEC)
+         except Empty:
+             logger.error("Parallel model forward timeout")
+             raise RuntimeError("Parallel model forward timeout")
+         return y
+
+     def __del__(self):
+         # Send terminate signal to all workers
+         for p in self.ctx.processes:
+             p.terminate()
+             p.join()
+         self.queue_in.close()
+         self.queue_out.close()
+
+
+ __all__ = ["ParallelModel"]
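
A hedged sketch of driving ParallelModel for two-way classifier-free-guidance parallelism; build_dit and the forward kwargs are placeholders, not part of this package:

    import torch

    from diffsynth_engine.utils.parallel import ParallelModel

    if __name__ == "__main__":  # mp.spawn needs an import-safe entry point
        dit = build_dit()  # hypothetical: any nn.Module called with keyword tensors
        model = ParallelModel(dit, cfg_degree=2, sp_ulysses_degree=1, sp_ring_degree=1, tp_degree=1)
        # rank 0 enqueues the kwargs; workers broadcast them, split dim 0 across
        # the CFG group, run forward, and all-gather the results back to rank 0.
        out = model(latents=torch.randn(2, 16, 64, 64), timestep=torch.tensor([999, 999]))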
diffsynth_engine/utils/prompt.py
@@ -0,0 +1,9 @@
+ from typing import Union, Optional
+
+ from diffsynth_engine.tokenizers import CLIPTokenizer, T5TokenizerFast
+
+
+ def tokenize_long_prompt(
+     tokenizer: Union[CLIPTokenizer, T5TokenizerFast], prompt: str, max_length: Optional[int] = None
+ ):
+     return tokenizer(prompt)["input_ids"]
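
As written, tokenize_long_prompt ignores max_length and simply delegates to the tokenizer. A call would look like this (the tokenizer construction below is an assumption; it is not shown in this diff):

    from diffsynth_engine.tokenizers import CLIPTokenizer
    from diffsynth_engine.utils.prompt import tokenize_long_prompt

    tokenizer = CLIPTokenizer.from_pretrained("path/to/tokenizer")  # assumed constructor
    input_ids = tokenize_long_prompt(tokenizer, "a watercolor painting of a fox")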
diffsynth_engine/utils/video.py
@@ -0,0 +1,40 @@
+ import imageio
+ import imageio.v3 as iio
+ import numpy as np
+ from PIL import Image
+ from typing import List
+
+
+ class VideoReader:
+     def __init__(self, path: str):
+         self.reader = imageio.get_reader(path)
+
+     def __len__(self):
+         return self.reader.count_frames()
+
+     def __getitem__(self, item):
+         return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
+
+     def __del__(self):
+         self.reader.close()
+
+     @property
+     def frames(self) -> List[Image.Image]:
+         return [self[i] for i in range(len(self))]
+
+
+ def load_video(path: str) -> VideoReader:
+     return VideoReader(path)
+
+
+ def save_video(frames, save_path, fps=15):
+     if save_path.endswith(".webm"):
+         codec = "libvpx-vp9"
+     elif save_path.endswith(".mp4"):
+         codec = "libx264"
+     else:
+         raise ValueError(f"unsupported video format for {save_path}; expected .webm or .mp4")
+
+     frames = [np.array(img) for img in frames]
+
+     # write the video with imageio's FFMPEG plugin
+     with iio.imopen(save_path, "w", plugin="FFMPEG") as writer:
+         writer.write(frames, fps=fps, codec=codec)
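
Round-trip usage follows directly from the API above (file paths are illustrative):

    from diffsynth_engine.utils.video import load_video, save_video

    reader = load_video("input.mp4")
    frames = reader.frames  # List[PIL.Image.Image], decoded as RGB
    save_video(frames, "output.webm", fps=15)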