checkpoint-engine 0.3.3__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/PKG-INFO +1 -1
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/_version.py +3 -3
- checkpoint_engine-0.4.0/checkpoint_engine/distributed/__init__.py +28 -0
- checkpoint_engine-0.4.0/checkpoint_engine/distributed/base.py +288 -0
- checkpoint_engine-0.4.0/checkpoint_engine/distributed/vllm_hccl.py +323 -0
- checkpoint_engine-0.4.0/checkpoint_engine/distributed/vllm_nccl.py +223 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/ps.py +55 -43
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/worker.py +49 -9
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine.egg-info/PKG-INFO +1 -1
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine.egg-info/SOURCES.txt +4 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/examples/update.py +4 -1
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/tests/test_reuse_pin_memory.py +2 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/tests/test_update.py +16 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/.github/workflows/cpu-tests.yml +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/.github/workflows/pre-commit.yaml +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/.github/workflows/python-publish.yml +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/.gitignore +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/.pre-commit-config.yaml +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/LICENCE +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/README.md +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/__init__.py +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/__main__.py +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/api.py +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/data_types.py +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/device_utils.py +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/p2p_store.py +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/pin_memory.py +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine.egg-info/requires.txt +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine.egg-info/top_level.txt +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/docs/npu_start.md +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/figures/checkpoint-engine.png +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/figures/overlap-update-and-copy.png +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/figures/pipeline.png +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/patches/vllm_fp8.patch +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/pyproject.toml +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/setup.cfg +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/tests/test_assign_receiver_ranks.py +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/tests/test_inplace_unpin.py +0 -0
- {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/tests/test_rdma_parser.py +0 -0
{checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.3.3
+Version: 0.4.0
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine

{checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/_version.py
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.3.3'
-__version_tuple__ = version_tuple = (0, 3, 3)
+__version__ = version = '0.4.0'
+__version_tuple__ = version_tuple = (0, 4, 0)
 
-__commit_id__ = commit_id = '
+__commit_id__ = commit_id = 'ge906b46e8'

checkpoint_engine-0.4.0/checkpoint_engine/distributed/__init__.py (new file)
@@ -0,0 +1,28 @@
+from .base import (
+    Distributed,
+    DistributedProcessGroup,
+    all_gather_object,
+    all_reduce,
+    barrier,
+    broadcast,
+    destroy_process_group,
+    init_process_group,
+    is_initialized,
+    new_group,
+    use_backend,
+)
+
+
+__all__ = [
+    "Distributed",
+    "DistributedProcessGroup",
+    "all_gather_object",
+    "all_reduce",
+    "barrier",
+    "broadcast",
+    "destroy_process_group",
+    "init_process_group",
+    "is_initialized",
+    "new_group",
+    "use_backend",
+]
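
The new `checkpoint_engine.distributed` package re-exports a backend-agnostic, torch.distributed-like surface. The following is a minimal usage sketch, not part of the diff: the rendezvous host/port, environment variables, and call order are assumptions about how a caller might drive these wrappers.

```python
# Hypothetical driver for the re-exported wrapper API (illustrative only).
import os

import torch.distributed as torch_dist

from checkpoint_engine import distributed as dist

rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

# Optional: swap the process-wide backend before init; None keeps the
# default torch.distributed-based TorchBackend.
dist.use_backend(None)

# The wrappers expect an externally created TCPStore for rendezvous.
store = torch_dist.TCPStore("127.0.0.1", 29500, world_size=world_size, is_master=(rank == 0))
# TorchBackend forwards extra kwargs such as backend/timeout to torch.distributed.
dist.init_process_group(rank=rank, world_size=world_size, store=store, backend="gloo")

gathered: list = [None] * world_size
dist.all_gather_object(gathered, {"rank": rank})  # every rank ends up with all payloads
dist.barrier()
dist.destroy_process_group()
```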

checkpoint_engine-0.4.0/checkpoint_engine/distributed/base.py (new file)
@@ -0,0 +1,288 @@
+import importlib
+import io
+import pickle
+from abc import ABC, abstractmethod
+from datetime import timedelta
+from typing import Any, Protocol
+
+import torch
+import torch.distributed as torch_dist
+
+
+class CommunicatorProtocol(Protocol):
+    def all_gather(self, *args: Any, **kwargs: Any) -> torch.Tensor: ...
+
+
+class CommGroup:
+    def __init__(self, comm_handle: int, ranks: list[int]):
+        self._comm = comm_handle
+        self._ranks = ranks
+
+    @property
+    def handle(self) -> int:
+        return self._comm
+
+    @property
+    def ranks(self) -> list[int]:
+        return self._ranks
+
+
+DistributedProcessGroup = torch_dist.ProcessGroup | CommGroup
+
+
+class Distributed(ABC):
+    @abstractmethod
+    def init_process_group(
+        self,
+        rank: int,
+        world_size: int,
+        store: torch_dist.TCPStore,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def destroy_process_group(
+        self,
+        group: DistributedProcessGroup | None = None,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def is_initialized(self) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def all_gather_object(
+        self,
+        object_list: list[Any],
+        obj: Any,
+        group: DistributedProcessGroup | None = None,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def all_reduce(
+        self,
+        tensor: torch.Tensor,
+        op: torch_dist.ReduceOp.RedOpType,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def broadcast(
+        self,
+        tensor: torch.Tensor,
+        src: int,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def barrier(
+        self,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def new_group(
+        self,
+        ranks: list[int],
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+
+class TorchBackend(Distributed):
+    def init_process_group(
+        self,
+        rank: int,
+        world_size: int,
+        store: torch_dist.TCPStore,
+        **kwargs,
+    ):
+        backend = kwargs.get("backend", "nccl")
+        timeout = kwargs.get("timeout", timedelta(minutes=10))
+
+        torch_dist.init_process_group(
+            backend=backend,
+            world_size=world_size,
+            rank=rank,
+            timeout=timeout,
+            store=store,
+        )
+
+    def destroy_process_group(self, group: DistributedProcessGroup | None = None):
+        torch_dist.destroy_process_group(group)
+
+    def is_initialized(self) -> bool:
+        return torch_dist.is_initialized()
+
+    def all_gather_object(
+        self, object_list: list[Any], obj: Any, group: DistributedProcessGroup | None = None
+    ):
+        torch_dist.all_gather_object(object_list, obj, group)
+
+    def all_reduce(
+        self,
+        tensor: torch.Tensor,
+        op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        torch_dist.all_reduce(tensor, op, group, **kwargs)
+
+    def broadcast(
+        self,
+        tensor: torch.Tensor,
+        src: int = 0,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        torch_dist.broadcast(tensor, src, group, **kwargs)
+
+    def barrier(self, group: DistributedProcessGroup | None = None, **kwargs):
+        torch_dist.barrier(group, **kwargs)
+
+    def new_group(self, ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
+        return torch_dist.new_group(ranks, **kwargs)
+
+
+# specific device instance
+_BACKEND_INSTANCE: Distributed = TorchBackend()
+
+_pickler = pickle.Pickler
+_unpickler = pickle.Unpickler
+
+
+def _object_to_tensor(obj: Any, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
+    f = io.BytesIO()
+    _pickler(f).dump(obj)
+    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())
+    byte_tensor = torch.ByteTensor(byte_storage).to(device)
+    local_size = torch.LongTensor([byte_tensor.numel()]).to(device)
+    return byte_tensor, local_size
+
+
+def _tensor_to_object(tensor: torch.Tensor, tensor_size: int) -> Any:
+    tensor = tensor.cpu()
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    return _unpickler(io.BytesIO(buf)).load()
+
+
+def _flatten_for_scatter_gather(
+    tensor_list: list[torch.Tensor], copy: bool = False
+) -> torch.Tensor:
+    if not tensor_list:
+        raise RuntimeError("Received an empty list.")
+    t = tensor_list[0]
+    buffer_shape = [len(tensor_list)] + list(t.shape)
+
+    buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device)
+    if copy:
+        for i, tensor in enumerate(tensor_list):
+            buffer[i].copy_(tensor)
+    return buffer
+
+
+def _common_all_gather_object(
+    comm: CommunicatorProtocol,
+    device: torch.device,
+    world_size: int,
+    object_list: list[Any],
+    object: Any,
+):
+    input_tensor, local_size = _object_to_tensor(object, device)
+    object_sizes_tensor = torch.empty(world_size, dtype=torch.long, device=device)
+    comm.all_gather(object_sizes_tensor, local_size)
+    object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(world_size)]
+    max_object_size = int(max(object_size_list).item())
+    input_tensor.resize_(max_object_size)
+    coalesced_output_tensor = torch.empty(
+        max_object_size * world_size, dtype=torch.uint8, device=device
+    )
+
+    comm.all_gather(coalesced_output_tensor, input_tensor)
+    output_tensors = [
+        coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
+        for i in range(world_size)
+    ]
+    for i, tensor in enumerate(output_tensors):
+        tensor = tensor.type(torch.uint8)
+        tensor_size = object_size_list[i]
+        object_list[i] = _tensor_to_object(tensor, tensor_size)
+
+
+def use_backend(backend: str | None):
+    global _BACKEND_INSTANCE
+
+    if not backend:
+        return
+
+    mapping = {
+        "vllm_nccl": ".vllm_nccl.DistributedNccl",
+        "vllm_hccl": ".vllm_hccl.DistributedHccl",
+    }
+    if backend not in mapping:
+        raise ValueError(f"Unsupported custom backend: {backend}")
+
+    module_path, class_name = mapping[backend].rsplit(".", 1)
+    module = importlib.import_module(module_path, "checkpoint_engine.distributed")
+    backend_class = getattr(module, class_name)
+    _BACKEND_INSTANCE = backend_class()
+
+
+def init_process_group(
+    rank: int,
+    world_size: int,
+    store: torch_dist.TCPStore,
+    **kwargs,
+):
+    _BACKEND_INSTANCE.init_process_group(rank, world_size, store, **kwargs)
+
+
+def destroy_process_group(group: DistributedProcessGroup | None = None):
+    _BACKEND_INSTANCE.destroy_process_group(group)
+
+
+def is_initialized() -> bool:
+    return _BACKEND_INSTANCE.is_initialized()
+
+
+def all_gather_object(
+    object_list: list[Any],
+    obj: Any,
+    group: DistributedProcessGroup | None = None,
+):
+    _BACKEND_INSTANCE.all_gather_object(object_list, obj, group)
+
+
+def all_reduce(
+    tensor: torch.Tensor,
+    op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
+    group: DistributedProcessGroup | None = None,
+    **kwargs,
+):
+    _BACKEND_INSTANCE.all_reduce(tensor, op, group, **kwargs)
+
+
+def broadcast(
+    tensor: torch.Tensor,
+    src: int = 0,
+    group: DistributedProcessGroup | None = None,
+    **kwargs,
+):
+    _BACKEND_INSTANCE.broadcast(tensor, src, group, **kwargs)
+
+
+def barrier(group: DistributedProcessGroup | None = None, **kwargs):
+    _BACKEND_INSTANCE.barrier(group, **kwargs)
+
+
+def new_group(ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
+    return _BACKEND_INSTANCE.new_group(ranks, **kwargs)
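
`_common_all_gather_object` in `base.py` lets any raw communicator satisfying `CommunicatorProtocol` gather arbitrary Python objects: each rank pickles its object into a byte tensor, ranks exchange sizes, buffers are padded to the maximum size, and each padded slice is trimmed and unpickled on the receiving side. Below is a CPU-only sketch of that pickle round-trip; `object_to_tensor`/`tensor_to_object` are stand-in names for the private helpers, and the padding step is simulated rather than gathered over a real communicator.

```python
import io
import pickle

import torch


def object_to_tensor(obj, device: torch.device):
    # Serialize to bytes, then view as a uint8 tensor plus its true length.
    buf = io.BytesIO()
    pickle.Pickler(buf).dump(obj)
    data = torch.frombuffer(bytearray(buf.getvalue()), dtype=torch.uint8).to(device)
    return data, torch.tensor([data.numel()], dtype=torch.long, device=device)


def tensor_to_object(tensor: torch.Tensor, size: int):
    # Drop any padding beyond the advertised size before unpickling.
    return pickle.Unpickler(io.BytesIO(tensor.cpu().numpy().tobytes()[:size])).load()


payload, length = object_to_tensor({"step": 1, "shards": [0, 1]}, torch.device("cpu"))
padded = torch.cat([payload, torch.zeros(8, dtype=torch.uint8)])  # what a peer would receive
print(tensor_to_object(padded, int(length.item())))  # {'step': 1, 'shards': [0, 1]}
```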

checkpoint_engine-0.4.0/checkpoint_engine/distributed/vllm_hccl.py (new file)
@@ -0,0 +1,323 @@
+import ctypes
+from contextlib import contextmanager
+from typing import Any, ClassVar
+
+import torch
+from torch.distributed import ReduceOp
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator
+from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
+    Function,
+    HCCLLibrary,
+    aclrtStream_t,
+    buffer_type,
+    hcclComm_t,
+    hcclDataType_t,
+    hcclDataTypeEnum,
+    hcclResult_t,
+)
+from vllm_ascend.utils import current_stream
+
+from checkpoint_engine.distributed.base import CommGroup, Distributed, _common_all_gather_object
+
+
+class HcclCommConfig(ctypes.Structure):
+    _fields_: ClassVar[list[tuple[str, Any]]] = [
+        ("size", ctypes.c_size_t),
+        ("magic_word", ctypes.c_uint32),
+        ("version", ctypes.c_uint32),
+        ("reserved", ctypes.c_uint64),
+        ("hccl_buffer_size", ctypes.c_uint32),
+        ("hccl_deterministic", ctypes.c_uint32),
+        ("hccl_comm_name", ctypes.c_char * 128),
+        ("hccl_udi", ctypes.c_char * 128),
+        ("hccl_op_expansion_mode", ctypes.c_uint32),
+        ("hccl_rdma_traffic_class", ctypes.c_uint32),
+        ("hccl_rdma_service_level", ctypes.c_uint32),
+        ("hcll_world_rank_id", ctypes.c_uint32),
+        ("hccl_job_id", ctypes.c_uint64),
+        ("comm_engine", ctypes.c_int32),
+        ("thread_num", ctypes.c_uint32),
+        ("notify_num_per_thread", ctypes.c_uint32),
+        ("acl_graph_zero_copy_enable", ctypes.c_uint8),
+    ]
+
+
+orig_exported_functions = HCCLLibrary.exported_functions
+extended_functions = [
+    # HcclResult HcclAllGather(
+    #   void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType,
+    #   HcclComm comm, alcrtStream stream
+    # )
+    Function(
+        "HcclAllGather",
+        hcclResult_t,
+        [
+            buffer_type,
+            buffer_type,
+            ctypes.c_uint64,
+            hcclDataType_t,
+            hcclComm_t,
+            aclrtStream_t,
+        ],
+    ),
+    # HcclResult HcclCreateSubCommConfig(
+    #   HcclComm *comm, uin32_t rankNum, uint32_t *rankIds, uint64_t subCommId,
+    #   uint32_t subCommRankId, HcclCommConfig *config, HcclComm *subComm
+    # )
+    Function(
+        "HcclCreateSubCommConfig",
+        hcclResult_t,
+        [
+            ctypes.POINTER(hcclComm_t),
+            ctypes.c_uint32,
+            ctypes.POINTER(ctypes.c_uint32),
+            ctypes.c_uint64,
+            ctypes.c_uint32,
+            ctypes.POINTER(HcclCommConfig),
+            ctypes.POINTER(hcclComm_t),
+        ],
+    ),
+]
+
+
+def hccl_all_gather(
+    self,  # noqa: ANN001
+    send_buf: buffer_type,
+    recv_buf: buffer_type,
+    count: ctypes.c_uint64,
+    data_type: hcclDataType_t,
+    comm: hcclComm_t,
+    stream: aclrtStream_t,
+):
+    self.HCCL_CHECK(
+        self._funcs["HcclAllGather"](send_buf, recv_buf, count, data_type, comm, stream)
+    )
+
+
+def hccl_create_subcomm_config(
+    self,  # noqa: ANN001
+    comm: hcclComm_t,
+    ranks_size: ctypes.c_uint32,
+    c_rank_ids: ctypes.POINTER(ctypes.c_uint32),
+    subcomm_id: ctypes.c_uint64,
+    subcomm_rank: ctypes.c_uint64,
+    comm_config: HcclCommConfig,
+) -> hcclComm_t:
+    subcomm = hcclComm_t()
+    self.HCCL_CHECK(
+        self._funcs["HcclCreateSubCommConfig"](
+            ctypes.byref(comm),
+            ranks_size,
+            c_rank_ids,
+            subcomm_id,
+            subcomm_rank,
+            ctypes.byref(comm_config),
+            ctypes.byref(subcomm),
+        )
+    )
+    return subcomm
+
+
+# extend HCCLLibrary
+HCCLLibrary.exported_functions = orig_exported_functions + extended_functions
+HCCLLibrary.hcclAllGather = hccl_all_gather
+HCCLLibrary.hcclCreateSubCommConfig = hccl_create_subcomm_config
+
+
+class PyHcclCommunicatorEx(PyHcclCommunicator):
+    def __init__(self, group: StatelessProcessGroup, device: torch.device):
+        super().__init__(group, device)
+        self.subcomm_id = 1
+
+    def destroy_comm(self, comm: hcclComm_t = None):
+        if comm:
+            self.hccl.hcclCommDestroy(comm)
+        else:
+            self.hccl.hcclCommDestroy(self.comm)
+
+    def all_gather(
+        self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream: torch.npu.Stream = None
+    ) -> torch.Tensor:
+        if self.disabled:
+            return
+        assert in_tensor.device == self.device, (
+            f"this hccl communicator is created to work on {self.device}, "
+            f"but the input tensor in on {in_tensor.device}"
+        )
+        if stream is None:
+            stream = current_stream()
+        self.hccl.hcclAllGather(
+            buffer_type(in_tensor.data_ptr()),
+            buffer_type(out_tensor.data_ptr()),
+            in_tensor.numel(),
+            hcclDataTypeEnum.from_torch(in_tensor.dtype),
+            self.comm,  # todo
+            aclrtStream_t(stream.npu_stream),
+        )
+        return out_tensor
+
+    def create_subcomm(self, ranks: list[int]) -> hcclComm_t:
+        comm_config = HcclCommConfig(
+            size=312,
+            magic_word=0xF0F0F0F0,
+            version=6,
+            reserved=0,
+            hccl_buffer_size=0xFFFFFFFF,
+            hccl_deterministic=0xFFFFFFFF,
+            hccl_comm_name=b"\0",
+            hccl_udi=b"\0",
+            hccl_op_expansize_mode=0,
+            hccl_rdma_traffic_class=0xFFFFFFFF,
+            hccl_rdma_service_level=0xFFFFFFFF,
+            hccl_world_rank_id=0,
+            hccl_job_id=0,
+            comm_engine=-1,
+            thread_num=0xFFFFFFFF,
+            notify_num_per_thread=0xFFFFFFFF,
+            acl_graph_zero_copy_enable=0,
+        )
+        uint32_array = ctypes.c_uint32 * len(ranks)
+        c_rank_ids = uint32_array(*ranks)
+        subcomm_rank = ranks.index(self.rank)
+        ranks_size = len(ranks)
+        subcomm_id = self.subcomm_id
+
+        subcomm = self.hccl.hcclCreateSubCommConfig(
+            self.comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config
+        )
+        self.subcomm_id += 1
+        return subcomm
+
+
+class DistributedHccl(Distributed):
+    def __init__(self):
+        self.pg: StatelessProcessGroup = None
+        self.pyhccl: PyHcclCommunicatorEx = None
+        self.sub_groups: dict[int, CommGroup] = {}
+        self.comm: hcclComm_t = None
+
+        self.host: str = None
+        self.port: int = None
+        self.rank: int = None
+        self.world_size: int = None
+        self.device: torch.device = None
+
+        self.initialized: bool = False
+
+    @contextmanager
+    def _use_group(self, group: CommGroup | None, src: int | None = None):
+        active_src = src
+        if group:
+            assert group.handle in self.sub_groups, "invalid sub_group"
+            newcomm = ctypes.c_void_p(group.handle)
+            self.pyhccl.comm = newcomm
+
+            if src is not None:
+                assert src in group.ranks, "src rank not in group"
+                # convert src rank id in default world to newcomm
+                active_src = group.ranks.index(src)
+                self.pyhccl.rank = group.ranks.index(self.rank)
+
+        try:
+            yield active_src
+        finally:
+            if group:
+                self.pyhccl.comm = self.comm
+                if src is not None:
+                    self.pyhccl.rank = self.rank
+
+    def init_process_group(
+        self,
+        rank: int,
+        world_size: int,
+        store: torch.distributed.TCPStore,
+        **kwargs,
+    ):
+        assert not self.initialized, "already initialized"
+
+        self.rank = rank
+        self.world_size = world_size
+        self.device = torch.device("npu", torch.npu.current_device())
+
+        self.pg = StatelessProcessGroup(rank=rank, world_size=world_size, store=store, socket=None)
+        self.pyhccl = PyHcclCommunicatorEx(group=self.pg, device=self.device)
+        self.comm = self.pyhccl.comm
+        self.initialized = True
+
+    def destroy_process_group(
+        self,
+        group: CommGroup | None = None,
+    ):
+        assert self.initialized, "not initialized"
+
+        if group and group.handle in self.sub_groups:
+            subcomm = ctypes.c_void_p(group.handle)
+            self.pyhccl.destroy_comm(subcomm)
+            del self.sub_groups[group.handle]
+            return
+
+        self.pyhccl.destroy_comm()
+        self.pyhccl = None
+        self.pg = None
+        self.initialized = False
+
+    def is_initialized(self) -> bool:
+        return self.initialized
+
+    def all_gather_object(self, object_list: list[Any], obj: Any, group: CommGroup | None = None):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            _common_all_gather_object(self.pyhccl, self.device, self.world_size, object_list, obj)
+            current_stream().synchronize()
+
+    def all_reduce(
+        self,
+        tensor: torch.Tensor,
+        op: ReduceOp.RedOpType = ReduceOp.SUM,
+        group: CommGroup | None = None,
+        **kwargs,
+    ):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            out_tensor = self.pyhccl.all_reduce(tensor, op)
+            current_stream().synchronize()
+            tensor.copy_(out_tensor)
+
+    def broadcast(
+        self, tensor: torch.Tensor, src: int | None = None, group: CommGroup | None = None, **kwargs
+    ):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group, src) as local_rank:
+            self.pyhccl.broadcast(tensor, local_rank)
+            current_stream().synchronize()
+
+    def barrier(self, group: CommGroup | None = None, **kwargs):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            data = torch.zeros(1, device=self.device)
+            self.pyhccl.all_reduce(data)
+            current_stream().synchronize()
+
+    def new_group(self, ranks: list[int], **kwargs) -> CommGroup | None:
+        assert self.initialized, "not initialized"
+
+        # ranks is None or []
+        if not ranks:
+            ranks = list(range(self.world_size))
+        else:
+            ranks.sort()
+
+        group: CommGroup = None
+        if self.rank not in ranks:
+            return group
+
+        subcomm = self.pyhccl.create_subcomm(ranks)
+        if subcomm:
+            group = CommGroup(subcomm.value, ranks)
+            self.sub_groups[subcomm.value] = group
+        return group
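
`vllm_hccl.py` extends vLLM-Ascend's `PyHcclCommunicator` with all-gather and sub-communicator support and wraps it as the `DistributedHccl` backend. Below is a hypothetical end-to-end sketch, not taken from the diff: it assumes an Ascend environment with `torch_npu` and `vllm_ascend` installed, an NPU already bound per process, and placeholder host/port, environment variables, and rank lists.

```python
import os

import torch
import torch.distributed as torch_dist

from checkpoint_engine import distributed as dist

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

dist.use_backend("vllm_hccl")  # installs DistributedHccl as the process-wide backend

store = torch_dist.TCPStore("127.0.0.1", 29500, world_size=world_size, is_master=(rank == 0))
dist.init_process_group(rank=rank, world_size=world_size, store=store)

# new_group builds an HCCL sub-communicator via HcclCreateSubCommConfig and
# returns a CommGroup wrapping its handle; ranks outside the group get None.
group = dist.new_group(ranks=[0, 1])
if group is not None:
    weights = torch.ones(4, device="npu")
    dist.broadcast(weights, src=0, group=group)  # src is a global rank, remapped inside _use_group
    dist.destroy_process_group(group)  # tears down only the sub-communicator

dist.barrier()
dist.destroy_process_group()  # tears down the default communicator
```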