diffsynth-engine 0.0.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +28 -0
- diffsynth_engine/algorithm/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +50 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +25 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +17 -0
- diffsynth_engine/algorithm/sampler/__init__.py +19 -0
- diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
- diffsynth_engine/conf/models/components/vae.json +254 -0
- diffsynth_engine/conf/models/flux/flux_dit.json +105 -0
- diffsynth_engine/conf/models/flux/flux_text_encoder.json +20 -0
- diffsynth_engine/conf/models/flux/flux_vae.json +250 -0
- diffsynth_engine/conf/models/sd/sd_text_encoder.json +220 -0
- diffsynth_engine/conf/models/sd/sd_unet.json +397 -0
- diffsynth_engine/conf/models/sd3/sd3_dit.json +908 -0
- diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +756 -0
- diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +455 -0
- diffsynth_engine/conf/models/sdxl/sdxl_unet.json +1056 -0
- diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +13 -0
- diffsynth_engine/conf/models/wan/dit/14b-i2v.json +13 -0
- diffsynth_engine/conf/models/wan/dit/14b-t2v.json +13 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
- diffsynth_engine/kernels/__init__.py +0 -0
- diffsynth_engine/models/__init__.py +7 -0
- diffsynth_engine/models/base.py +64 -0
- diffsynth_engine/models/basic/__init__.py +0 -0
- diffsynth_engine/models/basic/attention.py +217 -0
- diffsynth_engine/models/basic/lora.py +293 -0
- diffsynth_engine/models/basic/relative_position_emb.py +56 -0
- diffsynth_engine/models/basic/timestep.py +81 -0
- diffsynth_engine/models/basic/transformer_helper.py +88 -0
- diffsynth_engine/models/basic/unet_helper.py +244 -0
- diffsynth_engine/models/components/__init__.py +0 -0
- diffsynth_engine/models/components/clip.py +56 -0
- diffsynth_engine/models/components/t5.py +222 -0
- diffsynth_engine/models/components/vae.py +392 -0
- diffsynth_engine/models/flux/__init__.py +14 -0
- diffsynth_engine/models/flux/flux_dit.py +476 -0
- diffsynth_engine/models/flux/flux_text_encoder.py +88 -0
- diffsynth_engine/models/flux/flux_vae.py +78 -0
- diffsynth_engine/models/sd/__init__.py +12 -0
- diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
- diffsynth_engine/models/sd/sd_unet.py +293 -0
- diffsynth_engine/models/sd/sd_vae.py +38 -0
- diffsynth_engine/models/sd3/__init__.py +14 -0
- diffsynth_engine/models/sd3/sd3_dit.py +302 -0
- diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
- diffsynth_engine/models/sd3/sd3_vae.py +43 -0
- diffsynth_engine/models/sdxl/__init__.py +13 -0
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
- diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
- diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
- diffsynth_engine/models/utils.py +54 -0
- diffsynth_engine/models/wan/__init__.py +0 -0
- diffsynth_engine/models/wan/wan_dit.py +497 -0
- diffsynth_engine/models/wan/wan_image_encoder.py +494 -0
- diffsynth_engine/models/wan/wan_text_encoder.py +297 -0
- diffsynth_engine/models/wan/wan_vae.py +771 -0
- diffsynth_engine/pipelines/__init__.py +18 -0
- diffsynth_engine/pipelines/base.py +253 -0
- diffsynth_engine/pipelines/flux_image.py +512 -0
- diffsynth_engine/pipelines/sd_image.py +352 -0
- diffsynth_engine/pipelines/sdxl_image.py +395 -0
- diffsynth_engine/pipelines/wan_video.py +524 -0
- diffsynth_engine/tokenizers/__init__.py +6 -0
- diffsynth_engine/tokenizers/base.py +157 -0
- diffsynth_engine/tokenizers/clip.py +288 -0
- diffsynth_engine/tokenizers/t5.py +194 -0
- diffsynth_engine/tokenizers/wan.py +74 -0
- diffsynth_engine/utils/__init__.py +0 -0
- diffsynth_engine/utils/constants.py +34 -0
- diffsynth_engine/utils/download.py +135 -0
- diffsynth_engine/utils/env.py +7 -0
- diffsynth_engine/utils/flag.py +46 -0
- diffsynth_engine/utils/fp8_linear.py +64 -0
- diffsynth_engine/utils/gguf.py +415 -0
- diffsynth_engine/utils/loader.py +17 -0
- diffsynth_engine/utils/lock.py +56 -0
- diffsynth_engine/utils/logging.py +12 -0
- diffsynth_engine/utils/offload.py +44 -0
- diffsynth_engine/utils/parallel.py +390 -0
- diffsynth_engine/utils/prompt.py +9 -0
- diffsynth_engine/utils/video.py +40 -0
- diffsynth_engine-0.0.0.dist-info/LICENSE +201 -0
- diffsynth_engine-0.0.0.dist-info/METADATA +236 -0
- diffsynth_engine-0.0.0.dist-info/RECORD +127 -0
- diffsynth_engine-0.0.0.dist-info/WHEEL +5 -0
- diffsynth_engine-0.0.0.dist-info/top_level.txt +1 -0

diffsynth_engine/utils/lock.py
@@ -0,0 +1,56 @@
+import math
+import threading
+from typing import Optional
+from types import TracebackType
+from flufl.lock import Lock
+
+from diffsynth_engine.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class HeartbeatFileLock:
+    def __init__(self, lock_file_path: str, heartbeat_interval: float = 10):
+        self.lock_file_path = lock_file_path
+        self.heartbeat_interval = heartbeat_interval
+        self.lifetime = math.ceil(heartbeat_interval + 1)
+        self.heartbeat_thread = None
+        self.stop_event = threading.Event()
+        self.lock = None
+
+    def _heartbeat(self):
+        while not self.stop_event.is_set():
+            self.lock.refresh(lifetime=self.lifetime)
+            self.stop_event.wait(self.heartbeat_interval - 1)
+
+    def acquire(self):
+        self.lock = Lock(self.lock_file_path, lifetime=self.lifetime)
+        self.lock.lock()
+
+        self.heartbeat_thread = threading.Thread(target=self._heartbeat)
+        self.heartbeat_thread.start()
+
+    def release(self):
+        if self.lock is not None:
+            self.lock.unlock(unconditionally=True)
+        self._release()
+
+    def _release(self):
+        if self.heartbeat_thread is not None:
+            self.stop_event.set()
+            self.heartbeat_thread.join()
+
+    def __enter__(self):
+        self.acquire()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]] = None,
+        exc_value: Optional[BaseException] = None,
+        traceback: Optional[TracebackType] = None,
+    ):
+        self._release()
+
+    def __del__(self):
+        self.release()
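
For context, HeartbeatFileLock wraps flufl.lock.Lock with a background thread that refreshes the lock's lifetime roughly every heartbeat_interval seconds, so a long-running holder does not lose the lock. A minimal usage sketch, not part of the package; the lock path and the guarded work are hypothetical:

from diffsynth_engine.utils.lock import HeartbeatFileLock

# Only one process at a time enters the block; the heartbeat thread keeps the
# flufl lock alive while the work runs.
with HeartbeatFileLock("/tmp/diffsynth_model_download.lock", heartbeat_interval=10):
    download_model_files()  # hypothetical long-running critical section

Note that, as written, __exit__ only stops the heartbeat thread; the underlying flufl lock is unlocked by release(), which is called from __del__.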

diffsynth_engine/utils/logging.py
@@ -0,0 +1,12 @@
+import logging
+from typing import Optional
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(levelname)s - %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    return logging.getLogger(name)
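
The other modules in the wheel obtain their loggers through this helper, for example (illustrative only; the log message is hypothetical):

from diffsynth_engine.utils import logging

logger = logging.get_logger(__name__)
logger.info("pipeline initialized")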

diffsynth_engine/utils/offload.py
@@ -0,0 +1,44 @@
+import torch.nn as nn
+
+from diffsynth_engine.models.basic.transformer_helper import RMSNorm
+from diffsynth_engine.models.basic.relative_position_emb import RelativePositionEmbedding
+
+
+SUPPORTED_OFFLOAD_MODULES = (
+    nn.Embedding,
+    nn.Linear,
+    nn.LayerNorm,
+    nn.Conv2d,
+    nn.GroupNorm,
+    RMSNorm,
+    RelativePositionEmbedding,
+)
+
+
+def enable_sequential_cpu_offload(module: nn.Module, device: str = "cuda:0"):
+    if isinstance(module, SUPPORTED_OFFLOAD_MODULES):
+        add_cpu_offload_hook(module, device)
+        return
+    for submodule in module.children():
+        enable_sequential_cpu_offload(submodule, device)
+
+
+def add_cpu_offload_hook(module: nn.Module, device: str = "cuda:0"):
+    def _forward_pre_hook(module: nn.Module, input):
+        offload_params = {}
+        for name, param in module.named_parameters():
+            offload_params[name] = param.data
+            param.data = param.data.to(device=device)
+        setattr(module, "_offload_params", offload_params)
+
+    def _forward_hook(module: nn.Module, input, output):
+        offload_params = getattr(module, "_offload_params", {})
+        for name, param in module.named_parameters():
+            if name in offload_params:
+                param.data = offload_params[name]
+
+    if getattr(module, "_sequential_cpu_offload_enabled", False):
+        return
+    module.register_forward_pre_hook(_forward_pre_hook)
+    module.register_forward_hook(_forward_hook)
+    setattr(module, "_sequential_cpu_offload_enabled", True)
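
enable_sequential_cpu_offload walks the module tree and, for every supported submodule, registers a forward pre-hook that moves its parameters to the target device and a forward hook that swaps the CPU copies back afterwards, so a submodule's weights occupy GPU memory only while its own forward runs. A rough sketch of applying it, assuming a CUDA device is available; the toy nn.Sequential is illustrative and not from the package:

import torch
import torch.nn as nn

from diffsynth_engine.utils.offload import enable_sequential_cpu_offload

# Weights stay on CPU; each Linear is copied to cuda:0 only for the duration
# of its forward call and restored to CPU afterwards.
toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
enable_sequential_cpu_offload(toy, device="cuda:0")
out = toy(torch.randn(1, 8, device="cuda:0"))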

diffsynth_engine/utils/parallel.py
@@ -0,0 +1,390 @@
+import os
+import copy
+import torch
+import torch.nn as nn
+import torch.multiprocessing as mp
+import torch.distributed as dist
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import ShardingStrategy
+from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel.style import ParallelStyle
+from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
+from datetime import timedelta
+from functools import partial
+from typing import Callable, Dict, List, Union, Optional
+from queue import Empty
+from yunchang.globals import Singleton, set_seq_parallel_pg
+
+from diffsynth_engine.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class ProcessGroupSingleton(Singleton):
+    def __init__(self):
+        self.CFG_GROUP: dist.ProcessGroup = None
+        self.SP_GROUP: dist.ProcessGroup = None
+        self.TP_GROUP: dist.ProcessGroup = None
+
+        self.CFG_RANKS: List[int] = []
+        self.SP_RANKS: List[int] = []
+        self.TP_RANKS: List[int] = []
+
+
+PROCESS_GROUP = ProcessGroupSingleton()
+
+
+def get_cfg_group():
+    return PROCESS_GROUP.CFG_GROUP
+
+
+def get_cfg_world_size():
+    return PROCESS_GROUP.CFG_GROUP.size()
+
+
+def get_cfg_rank():
+    return PROCESS_GROUP.CFG_GROUP.rank()
+
+
+def get_cfg_ranks():
+    return PROCESS_GROUP.CFG_RANKS
+
+
+def get_sp_group():
+    return PROCESS_GROUP.SP_GROUP
+
+
+def get_sp_world_size():
+    return PROCESS_GROUP.SP_GROUP.size()
+
+
+def get_sp_rank():
+    return PROCESS_GROUP.SP_GROUP.rank()
+
+
+def get_sp_ranks():
+    return PROCESS_GROUP.SP_RANKS
+
+
+def get_tp_group():
+    return PROCESS_GROUP.TP_GROUP
+
+
+def get_tp_world_size():
+    return PROCESS_GROUP.TP_GROUP.size()
+
+
+def get_tp_rank():
+    return PROCESS_GROUP.TP_GROUP.rank()
+
+
+def get_tp_ranks():
+    return PROCESS_GROUP.TP_RANKS
+
+
+def init_parallel_pgs(
+    cfg_degree: int = 1,
+    sp_ulysses_degree: int = 1,
+    sp_ring_degree: int = 1,
+    tp_degree: int = 1,
+    rank: int = 0,
+    world_size: int = 1,
+):
+    sp_degree = sp_ulysses_degree * sp_ring_degree
+
+    assert sp_degree == 1 or tp_degree == 1, "not allowed to enable sequence parallel and tensor parallel together"
+    assert world_size == cfg_degree * sp_degree * tp_degree, (
+        f"world_size ({world_size}) must be equal to cfg_degree ({cfg_degree}) * sp_degree ({sp_degree}) * tp_degree ({tp_degree})"
+    )
+
+    def make_parallel_groups(blocks: List[List[int]], degree: int):
+        groups, chunks = [], []
+        for block in blocks:
+            size = len(block) // degree
+            chunk = [block[i * size : (i + 1) * size] for i in range(degree)]
+            chunks.extend(chunk)
+            groups.extend(list(zip(*chunk)))
+        return groups, chunks
+
+    blocks = [list(range(world_size))]
+    cfg_groups, cfg_blocks = make_parallel_groups(blocks, cfg_degree)
+    for cfg_ranks in cfg_groups:
+        cfg_group = dist.new_group(cfg_ranks)
+        if rank in cfg_ranks:
+            PROCESS_GROUP.CFG_GROUP = cfg_group
+            PROCESS_GROUP.CFG_RANKS = cfg_ranks
+
+    sp_groups, sp_blocks = make_parallel_groups(cfg_blocks, sp_degree)
+    for sp_ranks in sp_groups:
+        group = dist.new_group(sp_ranks)
+        if rank in sp_ranks:
+            PROCESS_GROUP.SP_GROUP = group
+            PROCESS_GROUP.SP_RANKS = sp_ranks
+
+    tp_groups, _ = make_parallel_groups(sp_blocks, tp_degree)
+    for tp_ranks in tp_groups:
+        group = dist.new_group(tp_ranks)
+        if rank in tp_ranks:
+            PROCESS_GROUP.TP_GROUP = group
+            PROCESS_GROUP.TP_RANKS = tp_ranks
+
+    set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
+
+
+def clone(data):
+    if isinstance(data, dict):
+        return {k: clone(v) for k, v in data.items()}
+    if isinstance(data, tuple) or isinstance(data, list):
+        return [clone(t) for t in data]
+    elif isinstance(data, torch.Tensor):
+        return data.clone()
+    else:
+        return copy.deepcopy(data)
+
+
+def to_device(data, device):
+    if isinstance(data, dict):
+        return {k: to_device(v, device) for k, v in data.items()}
+    if isinstance(data, tuple) or isinstance(data, list):
+        return [to_device(t, device) for t in data]
+    elif isinstance(data, torch.Tensor):
+        return data.to(device)
+    else:
+        return data
+
+
+def split_and_get(data, num, dim, index):
+    if isinstance(data, dict):
+        return {k: split_and_get(v, num, dim, index) for k, v in data.items()}
+    if isinstance(data, tuple) or isinstance(data, list):
+        return [split_and_get(t, num, dim, index) for t in data]
+    if isinstance(data, torch.Tensor):
+        if data.shape[dim] < num:
+            raise ValueError(f"data.shape[{dim}] ({data.shape[dim]}) < num ({num}), split failed")
+        return torch.split(data, data.shape[dim] // num, dim)[index]
+    return data
+
+
+def shard_model(
+    module: nn.Module,
+    device_id: int,
+    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD,
+    wrap_module_names: Optional[List[str]] = None,
+):
+    wrap_module_names = wrap_module_names or []
+
+    def wrap_fn(m):
+        for name in wrap_module_names:
+            submodule = getattr(module, name)
+            if isinstance(submodule, nn.ModuleList) and m in submodule:
+                return True
+            elif not isinstance(submodule, nn.ModuleList) and m is submodule:
+                return True
+        return False
+
+    return FSDP(
+        module,
+        device_id=device_id,
+        sharding_strategy=sharding_strategy,
+        auto_wrap_policy=partial(lambda_auto_wrap_policy, lambda_fn=wrap_fn),
+    )
+
+
+def parallelize_module(
+    module: nn.Module,
+    device_mesh: DeviceMesh,
+    parallelize_plan: Optional[Union[ParallelStyle, Dict[str, ParallelStyle]]] = None,
+):
+    _validate_tp_mesh_dim(device_mesh)
+    if parallelize_plan is None:
+        return module
+    if isinstance(parallelize_plan, ParallelStyle):
+        return parallelize_plan._apply(module, device_mesh)
+    for module_path, parallelize_style in parallelize_plan.items():
+        if module_path.strip() == "":
+            raise ValueError("Expect module path to be non-empty, but got empty string!")
+        try:
+            submodule = module.get_submodule(module_path)
+            parallelize_style._apply(submodule, device_mesh)
+        except AttributeError:
+            continue
+    return module
+
+
+NCCL_TIMEOUT_SEC = int(os.environ.get("NCCL_TIMEOUT_SEC", 600))
+PARALLEL_FWD_TIMEOUT_SEC = int(os.environ.get("PARALLEL_FWD_TIMEOUT_SEC", 300))
+PARALLEL_LORA_TIMEOUT_SEC = int(os.environ.get("PARALLEL_LORA_TIMEOUT_SEC ", 60))
+
+
+def _worker_loop(
+    rank: int,
+    world_size: int,
+    queue_in: mp.Queue,
+    queue_out: mp.Queue,
+    module: nn.Module,
+    cfg_degree: int,
+    sp_ulysses_degree: int,
+    sp_ring_degree: int,
+    tp_degree: int,
+    shard_fn: Optional[Callable] = None,
+    master_port: int = 29500,
+    device: str = "cuda",
+):
+    """
+    https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
+    """
+    try:
+        os.environ["RANK"] = str(rank)
+        os.environ["WORLD_SIZE"] = str(world_size)
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = str(master_port)
+        torch.cuda.set_device(rank)
+
+        timeout = timedelta(seconds=NCCL_TIMEOUT_SEC)
+        dist.init_process_group(
+            backend="nccl",
+            init_method="env://",
+            timeout=timeout,
+            world_size=world_size,
+            rank=rank,
+        )
+        init_parallel_pgs(
+            cfg_degree=cfg_degree,
+            sp_ulysses_degree=sp_ulysses_degree,
+            sp_ring_degree=sp_ring_degree,
+            tp_degree=tp_degree,
+            rank=rank,
+            world_size=world_size,
+        )
+
+        if tp_degree > 1:
+            module = parallelize_module(
+                module=module,
+                device_mesh=DeviceMesh(device, torch.tensor(get_tp_ranks())),
+                parallelize_plan=module.get_tp_plan(),
+            ).to(device)
+        elif shard_fn:
+            module = shard_fn(module=module, device_id=rank)
+        else:
+            module = module.to(device)
+
+        while True:
+            if rank == 0:
+                kwargs = queue_in.get()
+                data = [kwargs]
+            else:
+                data = [None]
+            dist.broadcast_object_list(data, src=0)
+            kwargs = clone(data[0])
+            del data
+
+            y = None
+            if kwargs.get("method", None) == "load_loras":
+                module.load_loras(lora_args=kwargs["lora_args"], fused=kwargs["fused"])
+            elif kwargs.get("method", None) == "unload_loras":
+                module.unload_loras()
+            else:
+                kwargs = to_device(kwargs, device)
+                kwargs = split_and_get(kwargs, get_cfg_world_size(), 0, get_cfg_rank())
+                with torch.no_grad():
+                    y = module(**kwargs)
+                if get_sp_rank() == 0 and get_tp_rank() == 0:
+                    gathered = torch.zeros((get_cfg_world_size(), *y.shape[1:]), dtype=y.dtype, device=y.device)
+                    dist.all_gather_into_tensor(gathered, y, group=get_cfg_group())
+                    y = gathered
+
+            if rank == 0:
+                queue_out.put(y)
+            dist.barrier()
+    except Exception as e:
+        import traceback
+
+        traceback.print_exc()
+        logger.error(f"Error in worker loop (rank {rank}): {e}")
+    finally:
+        del module
+        torch.cuda.synchronize()
+        torch.cuda.empty_cache()
+        dist.destroy_process_group()
+
+
+class ParallelModel(nn.Module):
+    def __init__(
+        self,
+        module: nn.Module,
+        cfg_degree: int,
+        sp_ulysses_degree: int,
+        sp_ring_degree: int,
+        tp_degree: int,
+        shard_fn: Optional[Callable] = None,
+        master_port: int = 29500,
+        device: str = "cuda",
+    ):
+        super().__init__()
+        self.world_size = cfg_degree * sp_ulysses_degree * sp_ring_degree * tp_degree
+        self.device = device
+        self.queue_in = mp.Queue()
+        self.queue_out = mp.Queue()
+        self.ctx = mp.spawn(
+            _worker_loop,
+            args=(
+                self.world_size,
+                self.queue_in,
+                self.queue_out,
+                module,
+                cfg_degree,
+                sp_ulysses_degree,
+                sp_ring_degree,
+                tp_degree,
+                shard_fn,
+                master_port,
+                device,
+            ),
+            nprocs=self.world_size,
+            join=False,
+        )
+
+    def load_loras(self, lora_args: List[Dict[str, any]], fused: bool = True):
+        self.queue_in.put(
+            {
+                "method": "load_loras",
+                "lora_args": lora_args,
+                "fused": fused,
+            }
+        )
+        try:
+            _ = self.queue_out.get(timeout=PARALLEL_LORA_TIMEOUT_SEC)
+        except Empty:
+            logger.error("Parallel model load LoRA timeout")
+            raise RuntimeError("Parallel model load LoRA timeout")
+        logger.info("Parallel model load LoRA done")
+
+    def unload_loras(self):
+        self.queue_in.put({"method": "unload_loras"})
+        try:
+            _ = self.queue_out.get(timeout=PARALLEL_LORA_TIMEOUT_SEC)
+        except Empty:
+            logger.error("Parallel model unload LoRA timeout")
+            raise RuntimeError("Parallel model unload LoRA timeout")
+        logger.info("Parallel model unload LoRA done")
+
+    def forward(self, **kwargs):
+        self.queue_in.put(kwargs)
+        try:
+            y = self.queue_out.get(timeout=PARALLEL_FWD_TIMEOUT_SEC)
+        except Empty:
+            logger.error("Parallel model forward timeout")
+            raise RuntimeError("Parallel model forward timeout")
+        return y
+
+    def __del__(self):
+        # Send terminate signal to all workers
+        for p in self.ctx.processes:
+            p.terminate()
+            p.join()
+        self.queue_in.close()
+        self.queue_out.close()
+
+
+__all__ = ["ParallelModel"]
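
In init_parallel_pgs, make_parallel_groups cuts each block of ranks into degree contiguous chunks and then zips across those chunks, so for world_size=4 with cfg_degree=2 and sp_ulysses_degree=2 the classifier-free-guidance groups come out as {0, 2} and {1, 3} while the sequence-parallel groups are {0, 1} and {2, 3}. ParallelModel then spawns one worker per rank: forward kwargs are queued by rank 0, broadcast to all workers, split along dim 0 across the CFG group, and the per-group outputs are gathered back on rank 0 before being returned. A rough usage sketch, not taken from the package documentation; the wrapped module, its keyword arguments, and the "blocks" attribute name are hypothetical:

from functools import partial

from diffsynth_engine.utils.parallel import ParallelModel, shard_model

# 4 GPUs: split the CFG batch across 2 ranks and use 2-way Ulysses sequence
# parallelism inside each CFG group.
parallel_dit = ParallelModel(
    dit_module,                 # hypothetical nn.Module, e.g. a loaded video DiT
    cfg_degree=2,
    sp_ulysses_degree=2,
    sp_ring_degree=1,
    tp_degree=1,
    shard_fn=partial(shard_model, wrap_module_names=["blocks"]),  # FSDP-wrap the transformer blocks
)
noise_pred = parallel_dit(x=latents, timestep=timestep)  # kwargs are broadcast to workers and split on dim 0 per CFG rank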

diffsynth_engine/utils/prompt.py
@@ -0,0 +1,9 @@
+from typing import Union, Optional
+
+from diffsynth_engine.tokenizers import CLIPTokenizer, T5TokenizerFast
+
+
+def tokenize_long_prompt(
+    tokenizer: Union[CLIPTokenizer, T5TokenizerFast], prompt: str, max_length: Optional[int] = None
+):
+    return tokenizer(prompt)["input_ids"]

diffsynth_engine/utils/video.py
@@ -0,0 +1,40 @@
+import imageio
+import imageio.v3 as iio
+import numpy as np
+from PIL import Image
+from typing import List
+
+
+class VideoReader:
+    def __init__(self, path: str):
+        self.reader = imageio.get_reader(path)
+
+    def __len__(self):
+        return self.reader.count_frames()
+
+    def __getitem__(self, item):
+        return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
+
+    def __del__(self):
+        self.reader.close()
+
+    @property
+    def frames(self) -> List[Image.Image]:
+        return [self[i] for i in range(len(self))]
+
+
+def load_video(path: str) -> VideoReader:
+    return VideoReader(path)
+
+
+def save_video(frames, save_path, fps=15):
+    if save_path.endswith(".webm"):
+        codec = "libvpx-vp9"
+    elif save_path.endswith(".mp4"):
+        codec = "libx264"
+
+    frames = [np.array(img) for img in frames]
+
+    # use imageio to write the .webm file
+    with iio.imopen(save_path, "w", plugin="FFMPEG") as writer:
+        writer.write(frames, fps=fps, codec=codec)
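
A short round-trip sketch for the helpers above; the file paths are hypothetical. Note that save_video only selects a codec for paths ending in .webm or .mp4, so other suffixes would leave codec undefined:

from diffsynth_engine.utils.video import load_video, save_video

video = load_video("input.mp4")            # hypothetical input path
frames = video.frames                      # decodes every frame to a PIL.Image
save_video(frames, "output.webm", fps=15)  # VP9 for .webm, H.264 for .mp4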