checkpoint-engine 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/_version.py +2 -2
- checkpoint_engine/distributed/__init__.py +28 -0
- checkpoint_engine/distributed/base.py +288 -0
- checkpoint_engine/distributed/vllm_hccl.py +323 -0
- checkpoint_engine/distributed/vllm_nccl.py +223 -0
- checkpoint_engine/ps.py +55 -43
- checkpoint_engine/worker.py +49 -9
- {checkpoint_engine-0.3.3.dist-info → checkpoint_engine-0.4.0.dist-info}/METADATA +1 -1
- checkpoint_engine-0.4.0.dist-info/RECORD +19 -0
- {checkpoint_engine-0.3.3.dist-info → checkpoint_engine-0.4.0.dist-info}/WHEEL +1 -1
- checkpoint_engine-0.3.3.dist-info/RECORD +0 -15
- {checkpoint_engine-0.3.3.dist-info → checkpoint_engine-0.4.0.dist-info}/licenses/LICENCE +0 -0
- {checkpoint_engine-0.3.3.dist-info → checkpoint_engine-0.4.0.dist-info}/top_level.txt +0 -0
checkpoint_engine/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.3.3'
-__version_tuple__ = version_tuple = (0, 3, 3)
+__version__ = version = '0.4.0'
+__version_tuple__ = version_tuple = (0, 4, 0)
 
 __commit_id__ = commit_id = None
checkpoint_engine/distributed/__init__.py
ADDED
@@ -0,0 +1,28 @@
+from .base import (
+    Distributed,
+    DistributedProcessGroup,
+    all_gather_object,
+    all_reduce,
+    barrier,
+    broadcast,
+    destroy_process_group,
+    init_process_group,
+    is_initialized,
+    new_group,
+    use_backend,
+)
+
+
+__all__ = [
+    "Distributed",
+    "DistributedProcessGroup",
+    "all_gather_object",
+    "all_reduce",
+    "barrier",
+    "broadcast",
+    "destroy_process_group",
+    "init_process_group",
+    "is_initialized",
+    "new_group",
+    "use_backend",
+]
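The facade re-exported here is driven entirely through these module-level functions: `use_backend` (defined in `base.py` below) swaps the process-wide backend instance, and every wrapper then dispatches to it. The following is a minimal usage sketch, not taken from the package docs; the rendezvous address, port, and two-rank layout are hypothetical, and one copy of this would run per rank.

# Hedged sketch: pick a backend, then drive the re-exported wrappers.
import os
from datetime import timedelta

import torch
import torch.distributed as torch_dist

import checkpoint_engine.distributed as dist

rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "2"))

dist.use_backend("vllm_nccl")  # or "vllm_hccl" on Ascend; omit to stay on torch.distributed

# Hypothetical rendezvous; the package expects a torch TCPStore to be provided.
store = torch_dist.TCPStore(
    os.environ.get("MASTER_ADDR", "127.0.0.1"), 29500, world_size,
    is_master=(rank == 0), timeout=timedelta(minutes=10),
)
dist.init_process_group(rank, world_size, store, timeout=timedelta(minutes=10))

tensor = torch.zeros(1, device="cuda")   # "npu" under the HCCL backend
dist.broadcast(tensor, src=0)
group = dist.new_group(list(range(world_size)))  # torch ProcessGroup or CommGroup
dist.barrier(group)
dist.destroy_process_group()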
checkpoint_engine/distributed/base.py
ADDED
@@ -0,0 +1,288 @@
+import importlib
+import io
+import pickle
+from abc import ABC, abstractmethod
+from datetime import timedelta
+from typing import Any, Protocol
+
+import torch
+import torch.distributed as torch_dist
+
+
+class CommunicatorProtocol(Protocol):
+    def all_gather(self, *args: Any, **kwargs: Any) -> torch.Tensor: ...
+
+
+class CommGroup:
+    def __init__(self, comm_handle: int, ranks: list[int]):
+        self._comm = comm_handle
+        self._ranks = ranks
+
+    @property
+    def handle(self) -> int:
+        return self._comm
+
+    @property
+    def ranks(self) -> list[int]:
+        return self._ranks
+
+
+DistributedProcessGroup = torch_dist.ProcessGroup | CommGroup
+
+
+class Distributed(ABC):
+    @abstractmethod
+    def init_process_group(
+        self,
+        rank: int,
+        world_size: int,
+        store: torch_dist.TCPStore,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def destroy_process_group(
+        self,
+        group: DistributedProcessGroup | None = None,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def is_initialized(self) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def all_gather_object(
+        self,
+        object_list: list[Any],
+        obj: Any,
+        group: DistributedProcessGroup | None = None,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def all_reduce(
+        self,
+        tensor: torch.Tensor,
+        op: torch_dist.ReduceOp.RedOpType,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def broadcast(
+        self,
+        tensor: torch.Tensor,
+        src: int,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def barrier(
+        self,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def new_group(
+        self,
+        ranks: list[int],
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+
+class TorchBackend(Distributed):
+    def init_process_group(
+        self,
+        rank: int,
+        world_size: int,
+        store: torch_dist.TCPStore,
+        **kwargs,
+    ):
+        backend = kwargs.get("backend", "nccl")
+        timeout = kwargs.get("timeout", timedelta(minutes=10))
+
+        torch_dist.init_process_group(
+            backend=backend,
+            world_size=world_size,
+            rank=rank,
+            timeout=timeout,
+            store=store,
+        )
+
+    def destroy_process_group(self, group: DistributedProcessGroup | None = None):
+        torch_dist.destroy_process_group(group)
+
+    def is_initialized(self) -> bool:
+        return torch_dist.is_initialized()
+
+    def all_gather_object(
+        self, object_list: list[Any], obj: Any, group: DistributedProcessGroup | None = None
+    ):
+        torch_dist.all_gather_object(object_list, obj, group)
+
+    def all_reduce(
+        self,
+        tensor: torch.Tensor,
+        op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        torch_dist.all_reduce(tensor, op, group, **kwargs)
+
+    def broadcast(
+        self,
+        tensor: torch.Tensor,
+        src: int = 0,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        torch_dist.broadcast(tensor, src, group, **kwargs)
+
+    def barrier(self, group: DistributedProcessGroup | None = None, **kwargs):
+        torch_dist.barrier(group, **kwargs)
+
+    def new_group(self, ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
+        return torch_dist.new_group(ranks, **kwargs)
+
+
+# specific device instance
+_BACKEND_INSTANCE: Distributed = TorchBackend()
+
+_pickler = pickle.Pickler
+_unpickler = pickle.Unpickler
+
+
+def _object_to_tensor(obj: Any, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
+    f = io.BytesIO()
+    _pickler(f).dump(obj)
+    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())
+    byte_tensor = torch.ByteTensor(byte_storage).to(device)
+    local_size = torch.LongTensor([byte_tensor.numel()]).to(device)
+    return byte_tensor, local_size
+
+
+def _tensor_to_object(tensor: torch.Tensor, tensor_size: int) -> Any:
+    tensor = tensor.cpu()
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    return _unpickler(io.BytesIO(buf)).load()
+
+
+def _flatten_for_scatter_gather(
+    tensor_list: list[torch.Tensor], copy: bool = False
+) -> torch.Tensor:
+    if not tensor_list:
+        raise RuntimeError("Received an empty list.")
+    t = tensor_list[0]
+    buffer_shape = [len(tensor_list)] + list(t.shape)
+
+    buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device)
+    if copy:
+        for i, tensor in enumerate(tensor_list):
+            buffer[i].copy_(tensor)
+    return buffer
+
+
+def _common_all_gather_object(
+    comm: CommunicatorProtocol,
+    device: torch.device,
+    world_size: int,
+    object_list: list[Any],
+    object: Any,
+):
+    input_tensor, local_size = _object_to_tensor(object, device)
+    object_sizes_tensor = torch.empty(world_size, dtype=torch.long, device=device)
+    comm.all_gather(object_sizes_tensor, local_size)
+    object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(world_size)]
+    max_object_size = int(max(object_size_list).item())
+    input_tensor.resize_(max_object_size)
+    coalesced_output_tensor = torch.empty(
+        max_object_size * world_size, dtype=torch.uint8, device=device
+    )
+
+    comm.all_gather(coalesced_output_tensor, input_tensor)
+    output_tensors = [
+        coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
+        for i in range(world_size)
+    ]
+    for i, tensor in enumerate(output_tensors):
+        tensor = tensor.type(torch.uint8)
+        tensor_size = object_size_list[i]
+        object_list[i] = _tensor_to_object(tensor, tensor_size)
+
+
+def use_backend(backend: str | None):
+    global _BACKEND_INSTANCE
+
+    if not backend:
+        return
+
+    mapping = {
+        "vllm_nccl": ".vllm_nccl.DistributedNccl",
+        "vllm_hccl": ".vllm_hccl.DistributedHccl",
+    }
+    if backend not in mapping:
+        raise ValueError(f"Unsupported custom backend: {backend}")
+
+    module_path, class_name = mapping[backend].rsplit(".", 1)
+    module = importlib.import_module(module_path, "checkpoint_engine.distributed")
+    backend_class = getattr(module, class_name)
+    _BACKEND_INSTANCE = backend_class()
+
+
+def init_process_group(
+    rank: int,
+    world_size: int,
+    store: torch_dist.TCPStore,
+    **kwargs,
+):
+    _BACKEND_INSTANCE.init_process_group(rank, world_size, store, **kwargs)
+
+
+def destroy_process_group(group: DistributedProcessGroup | None = None):
+    _BACKEND_INSTANCE.destroy_process_group(group)
+
+
+def is_initialized() -> bool:
+    return _BACKEND_INSTANCE.is_initialized()
+
+
+def all_gather_object(
+    object_list: list[Any],
+    obj: Any,
+    group: DistributedProcessGroup | None = None,
+):
+    _BACKEND_INSTANCE.all_gather_object(object_list, obj, group)
+
+
+def all_reduce(
+    tensor: torch.Tensor,
+    op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
+    group: DistributedProcessGroup | None = None,
+    **kwargs,
+):
+    _BACKEND_INSTANCE.all_reduce(tensor, op, group, **kwargs)
+
+
+def broadcast(
+    tensor: torch.Tensor,
+    src: int = 0,
+    group: DistributedProcessGroup | None = None,
+    **kwargs,
+):
+    _BACKEND_INSTANCE.broadcast(tensor, src, group, **kwargs)
+
+
+def barrier(group: DistributedProcessGroup | None = None, **kwargs):
+    _BACKEND_INSTANCE.barrier(group, **kwargs)
+
+
+def new_group(ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
+    return _BACKEND_INSTANCE.new_group(ranks, **kwargs)
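`_common_all_gather_object` works by pickling each rank's object into a byte tensor, all-gathering the per-rank sizes, padding every contribution to the maximum size, all-gathering the padded buffers, and unpickling each rank's slice. Below is a CPU-only sketch of just that serialization round trip, using the private helpers added above; it is illustrative and involves no communicator.

# Hedged sketch: the pickle <-> byte-tensor round trip behind
# _common_all_gather_object, run on CPU.
import torch

from checkpoint_engine.distributed.base import _object_to_tensor, _tensor_to_object

obj = {"rank": 0, "param_names": ["model.embed_tokens.weight"]}  # hypothetical payload
byte_tensor, local_size = _object_to_tensor(obj, torch.device("cpu"))

# Ranks may contribute different sizes, so the real code pads every contribution
# to the max size before the all_gather; padding is harmless because
# _tensor_to_object slices back down to the true size before unpickling.
padded = torch.zeros(int(local_size.item()) + 16, dtype=torch.uint8)
padded[: byte_tensor.numel()] = byte_tensor

restored = _tensor_to_object(padded, int(local_size.item()))
assert restored == obj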
checkpoint_engine/distributed/vllm_hccl.py
ADDED
@@ -0,0 +1,323 @@
+import ctypes
+from contextlib import contextmanager
+from typing import Any, ClassVar
+
+import torch
+from torch.distributed import ReduceOp
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator
+from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
+    Function,
+    HCCLLibrary,
+    aclrtStream_t,
+    buffer_type,
+    hcclComm_t,
+    hcclDataType_t,
+    hcclDataTypeEnum,
+    hcclResult_t,
+)
+from vllm_ascend.utils import current_stream
+
+from checkpoint_engine.distributed.base import CommGroup, Distributed, _common_all_gather_object
+
+
+class HcclCommConfig(ctypes.Structure):
+    _fields_: ClassVar[list[tuple[str, Any]]] = [
+        ("size", ctypes.c_size_t),
+        ("magic_word", ctypes.c_uint32),
+        ("version", ctypes.c_uint32),
+        ("reserved", ctypes.c_uint64),
+        ("hccl_buffer_size", ctypes.c_uint32),
+        ("hccl_deterministic", ctypes.c_uint32),
+        ("hccl_comm_name", ctypes.c_char * 128),
+        ("hccl_udi", ctypes.c_char * 128),
+        ("hccl_op_expansion_mode", ctypes.c_uint32),
+        ("hccl_rdma_traffic_class", ctypes.c_uint32),
+        ("hccl_rdma_service_level", ctypes.c_uint32),
+        ("hcll_world_rank_id", ctypes.c_uint32),
+        ("hccl_job_id", ctypes.c_uint64),
+        ("comm_engine", ctypes.c_int32),
+        ("thread_num", ctypes.c_uint32),
+        ("notify_num_per_thread", ctypes.c_uint32),
+        ("acl_graph_zero_copy_enable", ctypes.c_uint8),
+    ]
+
+
+orig_exported_functions = HCCLLibrary.exported_functions
+extended_functions = [
+    # HcclResult HcclAllGather(
+    #     void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType,
+    #     HcclComm comm, alcrtStream stream
+    # )
+    Function(
+        "HcclAllGather",
+        hcclResult_t,
+        [
+            buffer_type,
+            buffer_type,
+            ctypes.c_uint64,
+            hcclDataType_t,
+            hcclComm_t,
+            aclrtStream_t,
+        ],
+    ),
+    # HcclResult HcclCreateSubCommConfig(
+    #     HcclComm *comm, uin32_t rankNum, uint32_t *rankIds, uint64_t subCommId,
+    #     uint32_t subCommRankId, HcclCommConfig *config, HcclComm *subComm
+    # )
+    Function(
+        "HcclCreateSubCommConfig",
+        hcclResult_t,
+        [
+            ctypes.POINTER(hcclComm_t),
+            ctypes.c_uint32,
+            ctypes.POINTER(ctypes.c_uint32),
+            ctypes.c_uint64,
+            ctypes.c_uint32,
+            ctypes.POINTER(HcclCommConfig),
+            ctypes.POINTER(hcclComm_t),
+        ],
+    ),
+]
+
+
+def hccl_all_gather(
+    self,  # noqa: ANN001
+    send_buf: buffer_type,
+    recv_buf: buffer_type,
+    count: ctypes.c_uint64,
+    data_type: hcclDataType_t,
+    comm: hcclComm_t,
+    stream: aclrtStream_t,
+):
+    self.HCCL_CHECK(
+        self._funcs["HcclAllGather"](send_buf, recv_buf, count, data_type, comm, stream)
+    )
+
+
+def hccl_create_subcomm_config(
+    self,  # noqa: ANN001
+    comm: hcclComm_t,
+    ranks_size: ctypes.c_uint32,
+    c_rank_ids: ctypes.POINTER(ctypes.c_uint32),
+    subcomm_id: ctypes.c_uint64,
+    subcomm_rank: ctypes.c_uint64,
+    comm_config: HcclCommConfig,
+) -> hcclComm_t:
+    subcomm = hcclComm_t()
+    self.HCCL_CHECK(
+        self._funcs["HcclCreateSubCommConfig"](
+            ctypes.byref(comm),
+            ranks_size,
+            c_rank_ids,
+            subcomm_id,
+            subcomm_rank,
+            ctypes.byref(comm_config),
+            ctypes.byref(subcomm),
+        )
+    )
+    return subcomm
+
+
+# extend HCCLLibrary
+HCCLLibrary.exported_functions = orig_exported_functions + extended_functions
+HCCLLibrary.hcclAllGather = hccl_all_gather
+HCCLLibrary.hcclCreateSubCommConfig = hccl_create_subcomm_config
+
+
+class PyHcclCommunicatorEx(PyHcclCommunicator):
+    def __init__(self, group: StatelessProcessGroup, device: torch.device):
+        super().__init__(group, device)
+        self.subcomm_id = 1
+
+    def destroy_comm(self, comm: hcclComm_t = None):
+        if comm:
+            self.hccl.hcclCommDestroy(comm)
+        else:
+            self.hccl.hcclCommDestroy(self.comm)
+
+    def all_gather(
+        self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream: torch.npu.Stream = None
+    ) -> torch.Tensor:
+        if self.disabled:
+            return
+        assert in_tensor.device == self.device, (
+            f"this hccl communicator is created to work on {self.device}, "
+            f"but the input tensor in on {in_tensor.device}"
+        )
+        if stream is None:
+            stream = current_stream()
+        self.hccl.hcclAllGather(
+            buffer_type(in_tensor.data_ptr()),
+            buffer_type(out_tensor.data_ptr()),
+            in_tensor.numel(),
+            hcclDataTypeEnum.from_torch(in_tensor.dtype),
+            self.comm,  # todo
+            aclrtStream_t(stream.npu_stream),
+        )
+        return out_tensor
+
+    def create_subcomm(self, ranks: list[int]) -> hcclComm_t:
+        comm_config = HcclCommConfig(
+            size=312,
+            magic_word=0xF0F0F0F0,
+            version=6,
+            reserved=0,
+            hccl_buffer_size=0xFFFFFFFF,
+            hccl_deterministic=0xFFFFFFFF,
+            hccl_comm_name=b"\0",
+            hccl_udi=b"\0",
+            hccl_op_expansize_mode=0,
+            hccl_rdma_traffic_class=0xFFFFFFFF,
+            hccl_rdma_service_level=0xFFFFFFFF,
+            hccl_world_rank_id=0,
+            hccl_job_id=0,
+            comm_engine=-1,
+            thread_num=0xFFFFFFFF,
+            notify_num_per_thread=0xFFFFFFFF,
+            acl_graph_zero_copy_enable=0,
+        )
+        uint32_array = ctypes.c_uint32 * len(ranks)
+        c_rank_ids = uint32_array(*ranks)
+        subcomm_rank = ranks.index(self.rank)
+        ranks_size = len(ranks)
+        subcomm_id = self.subcomm_id
+
+        subcomm = self.hccl.hcclCreateSubCommConfig(
+            self.comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config
+        )
+        self.subcomm_id += 1
+        return subcomm
+
+
+class DistributedHccl(Distributed):
+    def __init__(self):
+        self.pg: StatelessProcessGroup = None
+        self.pyhccl: PyHcclCommunicatorEx = None
+        self.sub_groups: dict[int, CommGroup] = {}
+        self.comm: hcclComm_t = None
+
+        self.host: str = None
+        self.port: int = None
+        self.rank: int = None
+        self.world_size: int = None
+        self.device: torch.device = None
+
+        self.initialized: bool = False
+
+    @contextmanager
+    def _use_group(self, group: CommGroup | None, src: int | None = None):
+        active_src = src
+        if group:
+            assert group.handle in self.sub_groups, "invalid sub_group"
+            newcomm = ctypes.c_void_p(group.handle)
+            self.pyhccl.comm = newcomm
+
+            if src is not None:
+                assert src in group.ranks, "src rank not in group"
+                # convert src rank id in default world to newcomm
+                active_src = group.ranks.index(src)
+                self.pyhccl.rank = group.ranks.index(self.rank)
+
+        try:
+            yield active_src
+        finally:
+            if group:
+                self.pyhccl.comm = self.comm
+                if src is not None:
+                    self.pyhccl.rank = self.rank
+
+    def init_process_group(
+        self,
+        rank: int,
+        world_size: int,
+        store: torch.distributed.TCPStore,
+        **kwargs,
+    ):
+        assert not self.initialized, "already initialized"
+
+        self.rank = rank
+        self.world_size = world_size
+        self.device = torch.device("npu", torch.npu.current_device())
+
+        self.pg = StatelessProcessGroup(rank=rank, world_size=world_size, store=store, socket=None)
+        self.pyhccl = PyHcclCommunicatorEx(group=self.pg, device=self.device)
+        self.comm = self.pyhccl.comm
+        self.initialized = True
+
+    def destroy_process_group(
+        self,
+        group: CommGroup | None = None,
+    ):
+        assert self.initialized, "not initialized"
+
+        if group and group.handle in self.sub_groups:
+            subcomm = ctypes.c_void_p(group.handle)
+            self.pyhccl.destroy_comm(subcomm)
+            del self.sub_groups[group.handle]
+            return
+
+        self.pyhccl.destroy_comm()
+        self.pyhccl = None
+        self.pg = None
+        self.initialized = False
+
+    def is_initialized(self) -> bool:
+        return self.initialized
+
+    def all_gather_object(self, object_list: list[Any], obj: Any, group: CommGroup | None = None):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            _common_all_gather_object(self.pyhccl, self.device, self.world_size, object_list, obj)
+            current_stream().synchronize()
+
+    def all_reduce(
+        self,
+        tensor: torch.Tensor,
+        op: ReduceOp.RedOpType = ReduceOp.SUM,
+        group: CommGroup | None = None,
+        **kwargs,
+    ):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            out_tensor = self.pyhccl.all_reduce(tensor, op)
+            current_stream().synchronize()
+            tensor.copy_(out_tensor)
+
+    def broadcast(
+        self, tensor: torch.Tensor, src: int | None = None, group: CommGroup | None = None, **kwargs
+    ):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group, src) as local_rank:
+            self.pyhccl.broadcast(tensor, local_rank)
+            current_stream().synchronize()
+
+    def barrier(self, group: CommGroup | None = None, **kwargs):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            data = torch.zeros(1, device=self.device)
+            self.pyhccl.all_reduce(data)
+            current_stream().synchronize()
+
+    def new_group(self, ranks: list[int], **kwargs) -> CommGroup | None:
+        assert self.initialized, "not initialized"
+
+        # ranks is None or []
+        if not ranks:
+            ranks = list(range(self.world_size))
+        else:
+            ranks.sort()
+
+        group: CommGroup = None
+        if self.rank not in ranks:
+            return group
+
+        subcomm = self.pyhccl.create_subcomm(ranks)
+        if subcomm:
+            group = CommGroup(subcomm.value, ranks)
+            self.sub_groups[subcomm.value] = group
+        return group
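Inside `_use_group`, collectives on a sub-communicator temporarily swap `pyhccl.comm` and translate the global source rank into its index within `CommGroup.ranks`, since HCCL addresses ranks relative to the communicator. A dependency-free sketch of that translation, with hypothetical ranks:

# Hedged sketch: global-rank -> group-local-rank translation, mirroring
# DistributedHccl._use_group / DistributedNccl._use_group.
def to_group_local(global_src: int, group_ranks: list[int]) -> int:
    # group_ranks is the sorted member list kept in CommGroup.ranks
    assert global_src in group_ranks, "src rank not in group"
    return group_ranks.index(global_src)

# A sub-group built from global ranks [2, 3, 6, 7]: global rank 6 becomes
# local rank 2, which is what the underlying HCCL/NCCL call expects.
assert to_group_local(6, [2, 3, 6, 7]) == 2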
checkpoint_engine/distributed/vllm_nccl.py
ADDED
@@ -0,0 +1,223 @@
+import ctypes
+from contextlib import contextmanager
+from typing import Any, ClassVar
+
+import torch
+from torch.distributed import ReduceOp
+from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.distributed.device_communicators.pynccl_wrapper import (
+    Function,
+    NCCLLibrary,
+    ncclComm_t,
+    ncclResult_t,
+)
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm.utils import current_stream
+
+from checkpoint_engine.distributed.base import CommGroup, Distributed, _common_all_gather_object
+
+
+class NcclConfigT(ctypes.Structure):
+    _fields_: ClassVar[list[tuple[str, Any]]] = [
+        ("size", ctypes.c_size_t),
+        ("magic", ctypes.c_uint),
+        ("version", ctypes.c_uint),
+        ("blocking", ctypes.c_int),
+        ("cgaClusterSize", ctypes.c_int),
+        ("minCTAs", ctypes.c_int),
+        ("maxCTAs", ctypes.c_int),
+        ("netName", ctypes.c_char_p),
+        ("splitShare", ctypes.c_int),
+        ("trafficClass", ctypes.c_int),
+        ("commName", ctypes.c_char_p),
+        ("collnetEnable", ctypes.c_int),
+        ("CTAPolicy", ctypes.c_int),
+        ("shrinkShare", ctypes.c_int),
+        ("nvlsCTAs", ctypes.c_int),
+        ("nChannelsPerNetPeer", ctypes.c_int),
+        ("nvlinkCentricSched", ctypes.c_int),
+        ("graphUsageMode", ctypes.c_int),
+        ("numRmaCtx", ctypes.c_int),
+    ]
+
+
+nccl_orig_exported_functions = NCCLLibrary.exported_functions
+nccl_extended_functions = [
+    # ncclResult_t ncclCommSplit(
+    #     ncclComm_t comm, int color, int key, ncclComm_t *newcomm, NcclConfigT *config
+    # )
+    Function(
+        "ncclCommSplit",
+        ncclResult_t,
+        [
+            ncclComm_t,
+            ctypes.c_int,
+            ctypes.c_int,
+            ctypes.POINTER(ncclComm_t),
+            ctypes.POINTER(NcclConfigT),
+        ],
+    ),
+]
+
+
+def nccl_comm_split(
+    self,  # noqa: ANN001
+    comm: ncclComm_t,
+    color: int,
+    key: int,
+) -> ncclComm_t:
+    newcomm = ncclComm_t()
+
+    self.NCCL_CHECK(self._funcs["ncclCommSplit"](comm, color, key, ctypes.byref(newcomm), None))
+    return newcomm
+
+
+# extend NCCLLibrary
+NCCLLibrary.exported_functions = nccl_orig_exported_functions + nccl_extended_functions
+NCCLLibrary.ncclCommSplit = nccl_comm_split
+
+
+class PyNcclCommunicatorEx(PyNcclCommunicator):
+    def destroy_comm(self, comm: ncclComm_t = None):
+        if comm:
+            self.nccl.ncclCommDestroy(comm)
+        else:
+            self.nccl.ncclCommDestroy(self.comm)
+
+    def create_newcomm(self, ranks: list[int]) -> ncclComm_t:
+        if self.rank in ranks:
+            color = 0
+        else:
+            color = -1  # NCCL_SPLIT_NOCOLOR
+        newcomm = self.nccl.ncclCommSplit(self.comm, color, self.rank)
+        return newcomm
+
+
+class DistributedNccl(Distributed):
+    def __init__(self):
+        self.pg: StatelessProcessGroup = None
+        self.pynccl: PyNcclCommunicatorEx = None
+        self.sub_groups: dict[int, list[int]] = {}
+        self.comm: ncclComm_t = None
+
+        self.host: str = None
+        self.port: int = None
+        self.rank: int = None
+        self.world_size: int = None
+        self.device: torch.device = None
+
+        self.initialized: bool = False
+
+    @contextmanager
+    def _use_group(self, group: CommGroup | None, src: int | None = None):
+        active_src = src
+        if group:
+            assert group.handle in self.sub_groups, "invalid sub_group"
+            newcomm = ctypes.c_void_p(group.handle)
+            self.pynccl.comm = newcomm
+
+            if src is not None:
+                assert src in group.ranks, "src rank not in group"
+                # convert src rank id in default world to newcomm
+                active_src = group.ranks.index(src)
+                self.pynccl.rank = group.ranks.index(self.rank)
+
+        try:
+            yield active_src
+        finally:
+            if group:
+                self.pynccl.comm = self.comm
+                if src is not None:
+                    self.pynccl.rank = self.rank
+
+    def init_process_group(
+        self,
+        rank: int,
+        world_size: int,
+        store: torch.distributed.TCPStore,
+        **kwargs,
+    ):
+        assert not self.initialized, "already initialized"
+
+        self.rank = rank
+        self.world_size = world_size
+        self.device = torch.device("cuda", torch.cuda.current_device())
+
+        self.pg = StatelessProcessGroup(rank=rank, world_size=world_size, store=store, socket=None)
+        self.pynccl = PyNcclCommunicatorEx(group=self.pg, device=self.device)
+        self.comm = self.pynccl.comm
+        self.initialized = True
+
+    def destroy_process_group(
+        self,
+        group: CommGroup | None = None,
+    ):
+        assert self.initialized, "not initialized"
+
+        if group and group.handle in self.sub_groups:
+            newcomm = ctypes.c_void_p(group.handle)
+            self.pynccl.destroy_comm(newcomm)
+            del self.sub_groups[group.handle]
+            return
+
+        self.pynccl.destroy_comm()
+        self.pynccl = None
+        self.pg = None
+        self.initialized = False
+
+    def is_initialized(self) -> bool:
+        return self.initialized
+
+    def all_gather_object(self, object_list: list[Any], obj: Any, group: CommGroup | None = None):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            _common_all_gather_object(self.pynccl, self.device, self.world_size, object_list, obj)
+            current_stream().synchronize()
+
+    def all_reduce(
+        self,
+        tensor: torch.Tensor,
+        op: ReduceOp.RedOpType = ReduceOp.SUM,
+        group: CommGroup | None = None,
+        **kwargs,
+    ):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            out_tensor = self.pynccl.all_reduce(in_tensor=tensor, op=op)
+            current_stream().synchronize()
+            tensor.copy_(out_tensor)
+
+    def broadcast(
+        self, tensor: torch.Tensor, src: int | None = None, group: CommGroup | None = None, **kwargs
+    ):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group, src) as local_src:
+            self.pynccl.broadcast(tensor, local_src)
+            current_stream().synchronize()
+
+    def barrier(self, group: CommGroup | None = None, **kwargs):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            data = torch.zeros(1, device=self.device)
+            self.pynccl.all_reduce(data)
+            current_stream().synchronize()
+
+    def new_group(self, ranks: list[int], **kwargs) -> CommGroup | None:
+        assert self.initialized, "not initialized"
+
+        # ranks is None or []
+        if not ranks:
+            ranks = list(range(self.world_size))
+        else:
+            ranks.sort()
+
+        group: CommGroup = None
+        newcomm = self.pynccl.create_newcomm(ranks)
+        if newcomm:
+            group = CommGroup(newcomm.value, ranks)
+            self.sub_groups[newcomm.value] = group
+        return group
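`create_newcomm` follows standard `ncclCommSplit` semantics: every rank of the parent communicator participates in the split, members of the requested group pass color 0, non-members pass NCCL_SPLIT_NOCOLOR (-1, as noted in the code above) and receive no communicator, and the key (here the parent rank) fixes rank ordering inside the new communicator. A small sketch of that color/key assignment, with hypothetical ranks:

# Hedged sketch: the color/key each rank would pass to ncclCommSplit for a
# sub-group of global ranks [0, 1, 4, 5] in an 8-rank world.
NCCL_SPLIT_NOCOLOR = -1

def split_args(rank: int, group_ranks: list[int]) -> tuple[int, int]:
    color = 0 if rank in group_ranks else NCCL_SPLIT_NOCOLOR
    key = rank  # ordering inside the new communicator follows the parent rank
    return color, key

assert split_args(4, [0, 1, 4, 5]) == (0, 4)
assert split_args(2, [0, 1, 4, 5]) == (NCCL_SPLIT_NOCOLOR, 2)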
checkpoint_engine/ps.py
CHANGED
@@ -1,4 +1,5 @@
 import ctypes
+import gc
 import os
 import threading
 from collections import defaultdict
@@ -7,11 +8,12 @@ from datetime import timedelta
 from typing import TYPE_CHECKING
 
 import torch
-import torch.distributed as dist
+import torch.distributed
 import zmq
 from loguru import logger
 from torch.multiprocessing.reductions import reduce_tensor
 
+import checkpoint_engine.distributed as dist
 from checkpoint_engine.data_types import (
     BucketRange,
     DataToGather,
@@ -175,6 +177,8 @@ class ParameterServer:
         auto_pg: bool = True,
         gpu_count: int | None = None,
         mem_fraction: float | None = None,
+        master_addr: str | None = None,
+        master_port: int | None = None,
     ):
         """
         Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -228,6 +232,17 @@ class ParameterServer:
         self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
         self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
 
+        master_addr = master_addr or os.getenv("MASTER_ADDR")
+        assert master_addr, "master_addr is required"
+        self._store = torch.distributed.TCPStore(
+            master_addr,
+            _get_master_port(master_port),
+            self._world_size,
+            timeout=timedelta(minutes=10),
+            is_master=self._rank == 0,
+        )
+        self._store_counter = 0
+
     def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
         if checkpoint_name == self._current_shared_memory_pool_user:
             assert self._memory_pool[self.shared_memory_pool_name], (
@@ -487,8 +502,6 @@ class ParameterServer:
     def init_process_group(
         self,
         *,
-        master_addr: str | None = None,
-        master_port: int | None = None,
        timeout: timedelta = timedelta(minutes=10),
     ):
         """
@@ -498,27 +511,18 @@ class ParameterServer:
             master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
             timeout: The timeout of the process group.
         """
-
-
-        store = dist.TCPStore(
-            master_addr,
-            _get_master_port(master_port),
-            self._world_size,
-            timeout=timeout,
-            is_master=self._rank == 0,
-        )
+        self._store_counter += 1
+        sub_store = torch.distributed.PrefixStore(f"prefix-{self._store_counter}", self._store)
         dist.init_process_group(
             backend=self.device_manager.backend,
             world_size=self._world_size,
             rank=self._rank,
             timeout=timeout,
-            store=store,
+            store=sub_store,
         )
         logger.info(f"[rank{self._rank}] init process group successfully.")
 
-    def store_based_barrier(
-        self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
-    ) -> None:
+    def store_based_barrier(self, timeout: timedelta = timedelta(minutes=5)) -> None:
         """
         Perform a store-based barrier synchronization across all ranks.
 
@@ -529,9 +533,9 @@ class ParameterServer:
         Args:
            store: The TCPStore instance to use for synchronization.
         """
-        dist.distributed_c10d._store_based_barrier(
+        torch.distributed.distributed_c10d._store_based_barrier(
            rank=self._rank,
-            store=store,
+            store=self._store,
            group_name="parameter_server_barrier",
            rendezvous_count=self._world_size,
            timeout=timeout,
@@ -544,8 +548,6 @@ class ParameterServer:
         *,
         timeout: timedelta = timedelta(minutes=10),
         ranks: list[int] | None = None,
-        master_addr: str | None = None,
-        master_port: int | None = None,
     ) -> None:
         """
         Update the checkpoint to inference engine. This function should be called after gather_metas.
@@ -566,28 +568,12 @@ class ParameterServer:
         assert req_func is not None, "req_func is required"
         ranks_group = None
         try:
-
-
-            if self._auto_pg:
-                if not dist.is_initialized():
-                    self.init_process_group(
-                        timeout=timeout, master_addr=master_addr, master_port=master_port
-                    )
-                manager_store = dist.distributed_c10d._get_default_store()
-            else:
-                # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
-                # If master_port is provided, use master_port+1 for barrier store
-                manager_store = dist.TCPStore(
-                    master_addr,
-                    _get_master_port(master_port) + 1,
-                    self._world_size,
-                    timeout=timeout,
-                    is_master=self._rank == 0,
-                )
+            if self._auto_pg and not dist.is_initialized():
+                self.init_process_group(timeout=timeout)
             # if ranks is None or [], it will use fully broadcast to update to all ranks
             ranks_group = dist.new_group(ranks) if ranks else None
             self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
-            self.store_based_barrier(
+            self.store_based_barrier()
         except Exception as e:
             logger.exception(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
@@ -616,7 +602,10 @@ class ParameterServer:
         return socket, socket_paths
 
     def _detect_bucket_size(
-        self,
+        self,
+        ranks_group: dist.DistributedProcessGroup | None,
+        *,
+        disable_h2d_buffer: bool = False,
     ) -> tuple[int, bool]:
         GiB = 1 << 30  # noqa: N806
         # auto detect bucket size
@@ -633,7 +622,7 @@ class ParameterServer:
            dtype=torch.int64,
            device=self.device_manager.device_type,
         )
-        dist.all_reduce(tensor, op=
+        dist.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN, group=ranks_group)
         tensor = tensor.cpu()
         free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
         max_tensor_bytes = 0
@@ -735,7 +724,7 @@ class ParameterServer:
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
-        ranks_group: dist.
+        ranks_group: dist.DistributedProcessGroup | None,
         ranks: list[int] | None = None,
     ):
         assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
@@ -854,7 +843,7 @@ class ParameterServer:
                    f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
                )
                ret_code.fill_(1)
-            dist.all_reduce(ret_code, op=
+            dist.all_reduce(ret_code, op=torch.distributed.ReduceOp.SUM, group=ranks_group)
            self.device_manager.device_module.synchronize()
            if ret_code.item() != 0:
                # quit early if any rank failed
@@ -864,6 +853,29 @@ class ParameterServer:
                gidx += 1
 
            socket.recv()
+            device_mem = self.device_manager.device_module.mem_get_info()
+            logger.info(
+                f"[rank{self._rank}] weights broadcast done, device mem usage: {(device_mem[1] - device_mem[0]) / 1024 / 1024:.2f} MB, allocated memory: {self.device_manager.device_module.memory_allocated() / 1024 / 1024:.2f} MB, reserved memory: {self.device_manager.device_module.memory_reserved() / 1024 / 1024:.2f} MB"
+            )
+            # Notify worker to release handle
+            socket.send_pyobj(None)
+            socket.recv()
+            # Set to None in correct order (views first, then base tensors)
+            del buffer_b, h2d_buffer, buffer, handle
+            self.device_manager.device_module.synchronize()
+            gc.collect()
+            self.device_manager.device_module.ipc_collect()
+            self.device_manager.device_module.empty_cache()
+            self.device_manager.device_module.synchronize()
+
+            # Log actual memory usage
+            device_mem = self.device_manager.device_module.mem_get_info()
+            logger.info(
+                f"[rank{self._rank}] post-release: device mem usage: {(device_mem[1] - device_mem[0]) / 1024 / 1024:.2f} MB, "
+                f"allocated: {self.device_manager.device_module.memory_allocated() / 1024 / 1024:.2f} MB, "
+                f"reserved: {self.device_manager.device_module.memory_reserved() / 1024 / 1024:.2f} MB"
+            )
+            # Notify worker to call post_hook
            socket.send_pyobj(None)
            socket.recv()
        finally:
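The refactor keeps one long-lived `TCPStore` on the `ParameterServer` and wraps it in a fresh `torch.distributed.PrefixStore` with an incrementing prefix on every `init_process_group` call, so repeated init/destroy cycles do not collide on store keys. A minimal single-process sketch of that pattern (world size 1, gloo backend, hypothetical localhost port):

# Hedged sketch: one TCPStore reused across process-group re-initializations
# via per-init PrefixStore namespaces.
from datetime import timedelta

import torch.distributed as torch_dist

store = torch_dist.TCPStore("127.0.0.1", 29600, 1, is_master=True,
                            timeout=timedelta(seconds=30))

for counter in (1, 2):
    # A new prefix per initialization keeps rendezvous keys from clashing.
    sub_store = torch_dist.PrefixStore(f"prefix-{counter}", store)
    torch_dist.init_process_group(
        backend="gloo", rank=0, world_size=1, store=sub_store,
        timeout=timedelta(seconds=30),
    )
    torch_dist.barrier()
    torch_dist.destroy_process_group()  # the underlying TCPStore stays alive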
checkpoint_engine/worker.py
CHANGED
@@ -10,6 +10,9 @@ import zmq
 from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
 
 
+_WEIGHTS_TYPE = list[tuple[str, torch.Tensor]]
+
+
 def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
     func, args = handle
     list_args = list(args)
@@ -29,11 +32,9 @@ class FlattenedTensorMetadata(TypedDict):
     offset: int
 
 
-def _extract_weights(
-    payload: list[FlattenedTensorMetadata], buffer: torch.Tensor
-) -> list[tuple[str, torch.Tensor]]:
+def _extract_weights(payload: list[FlattenedTensorMetadata], buffer: torch.Tensor) -> _WEIGHTS_TYPE:
     assert buffer is not None
-    weights: list[tuple[str, torch.Tensor]] = []
+    weights: _WEIGHTS_TYPE = []
     for item in payload:
         shape = item["shape"]
         if isinstance(shape, list | tuple):
@@ -69,15 +70,35 @@ def update_weights_from_ipc(
         socket.send_string(msg)
         socket.recv() # wait for ack
         raise
+    # State machine:
+    # + receive tensor_metadata -> update_weights
+    # + receive Exception -> raise and stop
+    # + receive None first time -> release resources
+    # + receive None second time -> call post_hook and stop
     try:
+        released = False
         while True:
             payload: list[FlattenedTensorMetadata] | Exception | None = socket.recv_pyobj()
-            if payload is None:
+            if released:
+                assert payload is None, "Should not receive any payload after released"
                 if post_hook is not None:
                     post_hook()
                 device_manager.device_module.synchronize()
                 socket.send(b"")
                 break
+            if payload is None: # done signal
+                # TODO: wrap all messages to an object instead of None and Exception
+                device_manager.device_module.synchronize()
+                released = True
+                buffer = None
+                del ipc_handle
+
+                gc.collect()
+                device_manager.device_module.ipc_collect()
+                device_manager.device_module.empty_cache()
+                device_manager.device_module.synchronize()
+                socket.send(b"")
+                continue
             if isinstance(payload, list): # still updating weights
                 try:
                     run(_extract_weights(payload, buffer))
@@ -166,12 +187,31 @@ class VllmColocateWorkerExtension:
             self.device = torch.device(f"npu:{self.local_rank}")
         assert self.device is not None
 
+        def _load_weights(weights: _WEIGHTS_TYPE):
+            # Load main model weights
+            self.model_runner.model.load_weights(weights)
+            # Load drafter model weights if MTP/speculative decoding is enabled
+            if (
+                getattr(self.model_runner, "drafter", None) is not None
+                and getattr(self.model_runner.drafter, "model", None) is not None
+            ):
+                self.model_runner.drafter.model.load_weights(weights=weights)
+
+        def _post_hook():
+            process_weights_after_loading(self.model_runner.model, self.model_config, self.device)
+            # Also trigger drafter model's post processing if MTP is enabled
+            if (
+                getattr(self.model_runner, "drafter", None) is not None
+                and getattr(self.model_runner.drafter, "model", None) is not None
+            ):
+                process_weights_after_loading(
+                    self.model_runner.drafter.model, self.model_config, self.device
+                )
+
         update_weights_from_ipc(
             self._zmq_ctx,
             zmq_handles[self._device_uuid],
             device_id=self.device.index,
-            run=
-            post_hook=
-                self.model_runner.model, self.model_config, self.device
-            ),
+            run=_load_weights,
+            post_hook=_post_hook,
         )
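The updated worker loop is the small state machine described in the comments above: a metadata list triggers `run(...)`, an `Exception` aborts, the first `None` releases the IPC buffer, and the second `None` runs `post_hook` and exits. A transport-free sketch of that control flow, driven by an in-memory message list instead of the ZMQ socket and CUDA IPC handles:

# Hedged sketch: the None/None release-then-finalize protocol from
# update_weights_from_ipc, with plain callables standing in for ZMQ + IPC.
def drive(messages, run, post_hook):
    released = False
    for payload in messages:
        if released:
            assert payload is None, "Should not receive any payload after released"
            post_hook()
            return
        if payload is None:          # first None: release resources
            released = True
            continue
        if isinstance(payload, Exception):
            raise payload
        run(payload)                 # otherwise: a batch of tensor metadata

log = []
drive(
    messages=[["layer0.weight"], ["layer1.weight"], None, None],  # hypothetical stream
    run=lambda batch: log.append(("load", batch)),
    post_hook=lambda: log.append(("post_hook",)),
)
assert log[-1] == ("post_hook",)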
{checkpoint_engine-0.3.3.dist-info → checkpoint_engine-0.4.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.3.3
+Version: 0.4.0
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
checkpoint_engine-0.4.0.dist-info/RECORD
ADDED
@@ -0,0 +1,19 @@
+checkpoint_engine/__init__.py,sha256=OeWxe9mxl2sZ6cW-blSTg6JbFlOMpGbBghLZtxGOqXk,942
+checkpoint_engine/__main__.py,sha256=yzQlApuYo6eIOqtqM018RosyxNzXzB5a-stxUvsh-dg,709
+checkpoint_engine/_version.py,sha256=2_0GUP7yBCXRus-qiJKxQD62z172WSs1sQ6DVpPsbmM,704
+checkpoint_engine/api.py,sha256=JDiQ4i3Gb6GoaBhlp8lNuUPaVURoFFdeGJY9ZDDGvPc,3518
+checkpoint_engine/data_types.py,sha256=O9uAXjwB20iwrOHfEEQd8Y9CmaFspNJ9ks9noHqwQKk,2716
+checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
+checkpoint_engine/p2p_store.py,sha256=abiCDVmRISPt9QFfavHB9Jo7ZpBbSjUS1NevGuB-AVA,8721
+checkpoint_engine/pin_memory.py,sha256=b7nABKJV2bSIsOfX2YTHzUk1OkOze6AQjCaOIFaQnbA,16708
+checkpoint_engine/ps.py,sha256=DQ9-hvZJW0eA9d6bU1glIbSwYl4cZCmVRjFPPRF41YY,41957
+checkpoint_engine/worker.py,sha256=fTWiF6Gggehzjx4mnIFTDZFR-GwkEBUdTAC_ZLmsgZE,8649
+checkpoint_engine/distributed/__init__.py,sha256=fC0EEX1nfWkg8OolzAj5vd2P0x6s4hScOlwV8q8Uiik,492
+checkpoint_engine/distributed/base.py,sha256=dpdjcGXNdCdAUDPnX-vdJmCGXbGS4A69yNsd60t-UgA,7800
+checkpoint_engine/distributed/vllm_hccl.py,sha256=bLE-GrnOxu1GTw_2GIqu2o67_Sw7vgjzJnlMvvQz_8c,10313
+checkpoint_engine/distributed/vllm_nccl.py,sha256=nHnlY1jk--xNEjKDDnywx36FgrnjEGc9lrBBC3o-YzE,7015
+checkpoint_engine-0.4.0.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
+checkpoint_engine-0.4.0.dist-info/METADATA,sha256=qtS4bAI6SC3nBatKqAi5EVkty2zkvEZOWfshpaswF6k,11559
+checkpoint_engine-0.4.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+checkpoint_engine-0.4.0.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
+checkpoint_engine-0.4.0.dist-info/RECORD,,
checkpoint_engine-0.3.3.dist-info/RECORD
REMOVED
@@ -1,15 +0,0 @@
-checkpoint_engine/__init__.py,sha256=OeWxe9mxl2sZ6cW-blSTg6JbFlOMpGbBghLZtxGOqXk,942
-checkpoint_engine/__main__.py,sha256=yzQlApuYo6eIOqtqM018RosyxNzXzB5a-stxUvsh-dg,709
-checkpoint_engine/_version.py,sha256=lemL_4Kl75FgrO6lVuFrrtw6-Dcf9wtXBalKkXuzkO4,704
-checkpoint_engine/api.py,sha256=JDiQ4i3Gb6GoaBhlp8lNuUPaVURoFFdeGJY9ZDDGvPc,3518
-checkpoint_engine/data_types.py,sha256=O9uAXjwB20iwrOHfEEQd8Y9CmaFspNJ9ks9noHqwQKk,2716
-checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
-checkpoint_engine/p2p_store.py,sha256=abiCDVmRISPt9QFfavHB9Jo7ZpBbSjUS1NevGuB-AVA,8721
-checkpoint_engine/pin_memory.py,sha256=b7nABKJV2bSIsOfX2YTHzUk1OkOze6AQjCaOIFaQnbA,16708
-checkpoint_engine/ps.py,sha256=wBsHu2qWy5oRBrvLc7aEOroG_j58UJoWT6lFH4ylMRk,41092
-checkpoint_engine/worker.py,sha256=ghj9d2u8hY_U2uiOZWIN2CqRNZH6PrzujT22fHUFBWI,6879
-checkpoint_engine-0.3.3.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
-checkpoint_engine-0.3.3.dist-info/METADATA,sha256=WyyGLw1qrteQgRGOWhAm15NN2nzklTqw4iiQ9U2nYpQ,11559
-checkpoint_engine-0.3.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-checkpoint_engine-0.3.3.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
-checkpoint_engine-0.3.3.dist-info/RECORD,,

{checkpoint_engine-0.3.3.dist-info → checkpoint_engine-0.4.0.dist-info}/licenses/LICENCE
File without changes

{checkpoint_engine-0.3.3.dist-info → checkpoint_engine-0.4.0.dist-info}/top_level.txt
File without changes