checkpoint-engine 0.3.3.tar.gz → 0.4.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/PKG-INFO +1 -1
  2. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/_version.py +3 -3
  3. checkpoint_engine-0.4.0/checkpoint_engine/distributed/__init__.py +28 -0
  4. checkpoint_engine-0.4.0/checkpoint_engine/distributed/base.py +288 -0
  5. checkpoint_engine-0.4.0/checkpoint_engine/distributed/vllm_hccl.py +323 -0
  6. checkpoint_engine-0.4.0/checkpoint_engine/distributed/vllm_nccl.py +223 -0
  7. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/ps.py +55 -43
  8. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/worker.py +49 -9
  9. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine.egg-info/PKG-INFO +1 -1
  10. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine.egg-info/SOURCES.txt +4 -0
  11. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/examples/update.py +4 -1
  12. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/tests/test_reuse_pin_memory.py +2 -0
  13. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/tests/test_update.py +16 -0
  14. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/.github/workflows/cpu-tests.yml +0 -0
  15. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/.github/workflows/pre-commit.yaml +0 -0
  16. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/.github/workflows/python-publish.yml +0 -0
  17. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/.gitignore +0 -0
  18. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/.pre-commit-config.yaml +0 -0
  19. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/LICENCE +0 -0
  20. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/README.md +0 -0
  21. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/__init__.py +0 -0
  22. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/__main__.py +0 -0
  23. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/api.py +0 -0
  24. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/data_types.py +0 -0
  25. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/device_utils.py +0 -0
  26. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/p2p_store.py +0 -0
  27. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine/pin_memory.py +0 -0
  28. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
  29. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine.egg-info/requires.txt +0 -0
  30. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/checkpoint_engine.egg-info/top_level.txt +0 -0
  31. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/docs/npu_start.md +0 -0
  32. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/figures/checkpoint-engine.png +0 -0
  33. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/figures/overlap-update-and-copy.png +0 -0
  34. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/figures/pipeline.png +0 -0
  35. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/patches/vllm_fp8.patch +0 -0
  36. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/pyproject.toml +0 -0
  37. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/setup.cfg +0 -0
  38. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/tests/test_assign_receiver_ranks.py +0 -0
  39. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/tests/test_inplace_unpin.py +0 -0
  40. {checkpoint_engine-0.3.3 → checkpoint_engine-0.4.0}/tests/test_rdma_parser.py +0 -0
--- checkpoint_engine-0.3.3/PKG-INFO
+++ checkpoint_engine-0.4.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.3.3
+Version: 0.4.0
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
--- checkpoint_engine-0.3.3/checkpoint_engine/_version.py
+++ checkpoint_engine-0.4.0/checkpoint_engine/_version.py
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.3.3'
-__version_tuple__ = version_tuple = (0, 3, 3)
+__version__ = version = '0.4.0'
+__version_tuple__ = version_tuple = (0, 4, 0)
 
-__commit_id__ = commit_id = 'gf6910d646'
+__commit_id__ = commit_id = 'ge906b46e8'
--- /dev/null
+++ checkpoint_engine-0.4.0/checkpoint_engine/distributed/__init__.py
@@ -0,0 +1,28 @@
+from .base import (
+    Distributed,
+    DistributedProcessGroup,
+    all_gather_object,
+    all_reduce,
+    barrier,
+    broadcast,
+    destroy_process_group,
+    init_process_group,
+    is_initialized,
+    new_group,
+    use_backend,
+)
+
+
+__all__ = [
+    "Distributed",
+    "DistributedProcessGroup",
+    "all_gather_object",
+    "all_reduce",
+    "barrier",
+    "broadcast",
+    "destroy_process_group",
+    "init_process_group",
+    "is_initialized",
+    "new_group",
+    "use_backend",
+]
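The new checkpoint_engine.distributed package re-exports a torch.distributed-style facade from base.py (shown in the next hunk). A minimal usage sketch with the default Torch backend; the rank, world size, and store address are placeholder values, not anything shipped in the package:

import torch
import torch.distributed as torch_dist

from checkpoint_engine import distributed as dist

# Placeholder rendezvous values; a real launcher supplies these per process.
rank, world_size = 0, 2
store = torch_dist.TCPStore("127.0.0.1", 29500, world_size, is_master=(rank == 0))

dist.use_backend(None)  # keep the default TorchBackend (plain torch.distributed)
dist.init_process_group(rank, world_size, store, backend="gloo")

t = torch.ones(1)
dist.all_reduce(t)                            # forwarded to the active backend instance
objs = [None] * world_size
dist.all_gather_object(objs, {"rank": rank})  # gathers one object per rank
dist.barrier()
dist.destroy_process_group()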
--- /dev/null
+++ checkpoint_engine-0.4.0/checkpoint_engine/distributed/base.py
@@ -0,0 +1,288 @@
+import importlib
+import io
+import pickle
+from abc import ABC, abstractmethod
+from datetime import timedelta
+from typing import Any, Protocol
+
+import torch
+import torch.distributed as torch_dist
+
+
+class CommunicatorProtocol(Protocol):
+    def all_gather(self, *args: Any, **kwargs: Any) -> torch.Tensor: ...
+
+
+class CommGroup:
+    def __init__(self, comm_handle: int, ranks: list[int]):
+        self._comm = comm_handle
+        self._ranks = ranks
+
+    @property
+    def handle(self) -> int:
+        return self._comm
+
+    @property
+    def ranks(self) -> list[int]:
+        return self._ranks
+
+
+DistributedProcessGroup = torch_dist.ProcessGroup | CommGroup
+
+
+class Distributed(ABC):
+    @abstractmethod
+    def init_process_group(
+        self,
+        rank: int,
+        world_size: int,
+        store: torch_dist.TCPStore,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def destroy_process_group(
+        self,
+        group: DistributedProcessGroup | None = None,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def is_initialized(self) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def all_gather_object(
+        self,
+        object_list: list[Any],
+        obj: Any,
+        group: DistributedProcessGroup | None = None,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def all_reduce(
+        self,
+        tensor: torch.Tensor,
+        op: torch_dist.ReduceOp.RedOpType,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def broadcast(
+        self,
+        tensor: torch.Tensor,
+        src: int,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def barrier(
+        self,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def new_group(
+        self,
+        ranks: list[int],
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+
+class TorchBackend(Distributed):
+    def init_process_group(
+        self,
+        rank: int,
+        world_size: int,
+        store: torch_dist.TCPStore,
+        **kwargs,
+    ):
+        backend = kwargs.get("backend", "nccl")
+        timeout = kwargs.get("timeout", timedelta(minutes=10))
+
+        torch_dist.init_process_group(
+            backend=backend,
+            world_size=world_size,
+            rank=rank,
+            timeout=timeout,
+            store=store,
+        )
+
+    def destroy_process_group(self, group: DistributedProcessGroup | None = None):
+        torch_dist.destroy_process_group(group)
+
+    def is_initialized(self) -> bool:
+        return torch_dist.is_initialized()
+
+    def all_gather_object(
+        self, object_list: list[Any], obj: Any, group: DistributedProcessGroup | None = None
+    ):
+        torch_dist.all_gather_object(object_list, obj, group)
+
+    def all_reduce(
+        self,
+        tensor: torch.Tensor,
+        op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        torch_dist.all_reduce(tensor, op, group, **kwargs)
+
+    def broadcast(
+        self,
+        tensor: torch.Tensor,
+        src: int = 0,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        torch_dist.broadcast(tensor, src, group, **kwargs)
+
+    def barrier(self, group: DistributedProcessGroup | None = None, **kwargs):
+        torch_dist.barrier(group, **kwargs)
+
+    def new_group(self, ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
+        return torch_dist.new_group(ranks, **kwargs)
+
+
+# specific device instance
+_BACKEND_INSTANCE: Distributed = TorchBackend()
+
+_pickler = pickle.Pickler
+_unpickler = pickle.Unpickler
+
+
+def _object_to_tensor(obj: Any, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
+    f = io.BytesIO()
+    _pickler(f).dump(obj)
+    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())
+    byte_tensor = torch.ByteTensor(byte_storage).to(device)
+    local_size = torch.LongTensor([byte_tensor.numel()]).to(device)
+    return byte_tensor, local_size
+
+
+def _tensor_to_object(tensor: torch.Tensor, tensor_size: int) -> Any:
+    tensor = tensor.cpu()
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    return _unpickler(io.BytesIO(buf)).load()
+
+
+def _flatten_for_scatter_gather(
+    tensor_list: list[torch.Tensor], copy: bool = False
+) -> torch.Tensor:
+    if not tensor_list:
+        raise RuntimeError("Received an empty list.")
+    t = tensor_list[0]
+    buffer_shape = [len(tensor_list)] + list(t.shape)
+
+    buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device)
+    if copy:
+        for i, tensor in enumerate(tensor_list):
+            buffer[i].copy_(tensor)
+    return buffer
+
+
+def _common_all_gather_object(
+    comm: CommunicatorProtocol,
+    device: torch.device,
+    world_size: int,
+    object_list: list[Any],
+    object: Any,
+):
+    input_tensor, local_size = _object_to_tensor(object, device)
+    object_sizes_tensor = torch.empty(world_size, dtype=torch.long, device=device)
+    comm.all_gather(object_sizes_tensor, local_size)
+    object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(world_size)]
+    max_object_size = int(max(object_size_list).item())
+    input_tensor.resize_(max_object_size)
+    coalesced_output_tensor = torch.empty(
+        max_object_size * world_size, dtype=torch.uint8, device=device
+    )
+
+    comm.all_gather(coalesced_output_tensor, input_tensor)
+    output_tensors = [
+        coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
+        for i in range(world_size)
+    ]
+    for i, tensor in enumerate(output_tensors):
+        tensor = tensor.type(torch.uint8)
+        tensor_size = object_size_list[i]
+        object_list[i] = _tensor_to_object(tensor, tensor_size)
+
+
+def use_backend(backend: str | None):
+    global _BACKEND_INSTANCE
+
+    if not backend:
+        return
+
+    mapping = {
+        "vllm_nccl": ".vllm_nccl.DistributedNccl",
+        "vllm_hccl": ".vllm_hccl.DistributedHccl",
+    }
+    if backend not in mapping:
+        raise ValueError(f"Unsupported custom backend: {backend}")
+
+    module_path, class_name = mapping[backend].rsplit(".", 1)
+    module = importlib.import_module(module_path, "checkpoint_engine.distributed")
+    backend_class = getattr(module, class_name)
+    _BACKEND_INSTANCE = backend_class()
+
+
+def init_process_group(
+    rank: int,
+    world_size: int,
+    store: torch_dist.TCPStore,
+    **kwargs,
+):
+    _BACKEND_INSTANCE.init_process_group(rank, world_size, store, **kwargs)
+
+
+def destroy_process_group(group: DistributedProcessGroup | None = None):
+    _BACKEND_INSTANCE.destroy_process_group(group)
+
+
+def is_initialized() -> bool:
+    return _BACKEND_INSTANCE.is_initialized()
+
+
+def all_gather_object(
+    object_list: list[Any],
+    obj: Any,
+    group: DistributedProcessGroup | None = None,
+):
+    _BACKEND_INSTANCE.all_gather_object(object_list, obj, group)
+
+
+def all_reduce(
+    tensor: torch.Tensor,
+    op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
+    group: DistributedProcessGroup | None = None,
+    **kwargs,
+):
+    _BACKEND_INSTANCE.all_reduce(tensor, op, group, **kwargs)
+
+
+def broadcast(
+    tensor: torch.Tensor,
+    src: int = 0,
+    group: DistributedProcessGroup | None = None,
+    **kwargs,
+):
+    _BACKEND_INSTANCE.broadcast(tensor, src, group, **kwargs)
+
+
+def barrier(group: DistributedProcessGroup | None = None, **kwargs):
+    _BACKEND_INSTANCE.barrier(group, **kwargs)
+
+
+def new_group(ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
+    return _BACKEND_INSTANCE.new_group(ranks, **kwargs)
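base.py keeps a single module-level _BACKEND_INSTANCE (a TorchBackend by default), and every module-level helper simply forwards to it; use_backend() swaps that instance by lazily importing one of the vLLM-specific modules. A small illustrative sketch of that dispatch, not part of the package itself:

from checkpoint_engine import distributed as dist

# Default: calls go through TorchBackend, i.e. plain torch.distributed.
dist.use_backend(None)

# Selecting "vllm_nccl" lazily imports checkpoint_engine.distributed.vllm_nccl
# and installs DistributedNccl as the backend; vLLM must be installed for this,
# so the call is left commented out here.
# dist.use_backend("vllm_nccl")

# Unknown names fail fast with a ValueError.
try:
    dist.use_backend("not_a_backend")
except ValueError as err:
    print(err)  # -> Unsupported custom backend: not_a_backend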
--- /dev/null
+++ checkpoint_engine-0.4.0/checkpoint_engine/distributed/vllm_hccl.py
@@ -0,0 +1,323 @@
+import ctypes
+from contextlib import contextmanager
+from typing import Any, ClassVar
+
+import torch
+from torch.distributed import ReduceOp
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator
+from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
+    Function,
+    HCCLLibrary,
+    aclrtStream_t,
+    buffer_type,
+    hcclComm_t,
+    hcclDataType_t,
+    hcclDataTypeEnum,
+    hcclResult_t,
+)
+from vllm_ascend.utils import current_stream
+
+from checkpoint_engine.distributed.base import CommGroup, Distributed, _common_all_gather_object
+
+
+class HcclCommConfig(ctypes.Structure):
+    _fields_: ClassVar[list[tuple[str, Any]]] = [
+        ("size", ctypes.c_size_t),
+        ("magic_word", ctypes.c_uint32),
+        ("version", ctypes.c_uint32),
+        ("reserved", ctypes.c_uint64),
+        ("hccl_buffer_size", ctypes.c_uint32),
+        ("hccl_deterministic", ctypes.c_uint32),
+        ("hccl_comm_name", ctypes.c_char * 128),
+        ("hccl_udi", ctypes.c_char * 128),
+        ("hccl_op_expansion_mode", ctypes.c_uint32),
+        ("hccl_rdma_traffic_class", ctypes.c_uint32),
+        ("hccl_rdma_service_level", ctypes.c_uint32),
+        ("hcll_world_rank_id", ctypes.c_uint32),
+        ("hccl_job_id", ctypes.c_uint64),
+        ("comm_engine", ctypes.c_int32),
+        ("thread_num", ctypes.c_uint32),
+        ("notify_num_per_thread", ctypes.c_uint32),
+        ("acl_graph_zero_copy_enable", ctypes.c_uint8),
+    ]
+
+
+orig_exported_functions = HCCLLibrary.exported_functions
+extended_functions = [
+    # HcclResult HcclAllGather(
+    #     void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType,
+    #     HcclComm comm, alcrtStream stream
+    # )
+    Function(
+        "HcclAllGather",
+        hcclResult_t,
+        [
+            buffer_type,
+            buffer_type,
+            ctypes.c_uint64,
+            hcclDataType_t,
+            hcclComm_t,
+            aclrtStream_t,
+        ],
+    ),
+    # HcclResult HcclCreateSubCommConfig(
+    #     HcclComm *comm, uin32_t rankNum, uint32_t *rankIds, uint64_t subCommId,
+    #     uint32_t subCommRankId, HcclCommConfig *config, HcclComm *subComm
+    # )
+    Function(
+        "HcclCreateSubCommConfig",
+        hcclResult_t,
+        [
+            ctypes.POINTER(hcclComm_t),
+            ctypes.c_uint32,
+            ctypes.POINTER(ctypes.c_uint32),
+            ctypes.c_uint64,
+            ctypes.c_uint32,
+            ctypes.POINTER(HcclCommConfig),
+            ctypes.POINTER(hcclComm_t),
+        ],
+    ),
+]
+
+
+def hccl_all_gather(
+    self,  # noqa: ANN001
+    send_buf: buffer_type,
+    recv_buf: buffer_type,
+    count: ctypes.c_uint64,
+    data_type: hcclDataType_t,
+    comm: hcclComm_t,
+    stream: aclrtStream_t,
+):
+    self.HCCL_CHECK(
+        self._funcs["HcclAllGather"](send_buf, recv_buf, count, data_type, comm, stream)
+    )
+
+
+def hccl_create_subcomm_config(
+    self,  # noqa: ANN001
+    comm: hcclComm_t,
+    ranks_size: ctypes.c_uint32,
+    c_rank_ids: ctypes.POINTER(ctypes.c_uint32),
+    subcomm_id: ctypes.c_uint64,
+    subcomm_rank: ctypes.c_uint64,
+    comm_config: HcclCommConfig,
+) -> hcclComm_t:
+    subcomm = hcclComm_t()
+    self.HCCL_CHECK(
+        self._funcs["HcclCreateSubCommConfig"](
+            ctypes.byref(comm),
+            ranks_size,
+            c_rank_ids,
+            subcomm_id,
+            subcomm_rank,
+            ctypes.byref(comm_config),
+            ctypes.byref(subcomm),
+        )
+    )
+    return subcomm
+
+
+# extend HCCLLibrary
+HCCLLibrary.exported_functions = orig_exported_functions + extended_functions
+HCCLLibrary.hcclAllGather = hccl_all_gather
+HCCLLibrary.hcclCreateSubCommConfig = hccl_create_subcomm_config
+
+
+class PyHcclCommunicatorEx(PyHcclCommunicator):
+    def __init__(self, group: StatelessProcessGroup, device: torch.device):
+        super().__init__(group, device)
+        self.subcomm_id = 1
+
+    def destroy_comm(self, comm: hcclComm_t = None):
+        if comm:
+            self.hccl.hcclCommDestroy(comm)
+        else:
+            self.hccl.hcclCommDestroy(self.comm)
+
+    def all_gather(
+        self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream: torch.npu.Stream = None
+    ) -> torch.Tensor:
+        if self.disabled:
+            return
+        assert in_tensor.device == self.device, (
+            f"this hccl communicator is created to work on {self.device}, "
+            f"but the input tensor in on {in_tensor.device}"
+        )
+        if stream is None:
+            stream = current_stream()
+        self.hccl.hcclAllGather(
+            buffer_type(in_tensor.data_ptr()),
+            buffer_type(out_tensor.data_ptr()),
+            in_tensor.numel(),
+            hcclDataTypeEnum.from_torch(in_tensor.dtype),
+            self.comm,  # todo
+            aclrtStream_t(stream.npu_stream),
+        )
+        return out_tensor
+
+    def create_subcomm(self, ranks: list[int]) -> hcclComm_t:
+        comm_config = HcclCommConfig(
+            size=312,
+            magic_word=0xF0F0F0F0,
+            version=6,
+            reserved=0,
+            hccl_buffer_size=0xFFFFFFFF,
+            hccl_deterministic=0xFFFFFFFF,
+            hccl_comm_name=b"\0",
+            hccl_udi=b"\0",
+            hccl_op_expansize_mode=0,
+            hccl_rdma_traffic_class=0xFFFFFFFF,
+            hccl_rdma_service_level=0xFFFFFFFF,
+            hccl_world_rank_id=0,
+            hccl_job_id=0,
+            comm_engine=-1,
+            thread_num=0xFFFFFFFF,
+            notify_num_per_thread=0xFFFFFFFF,
+            acl_graph_zero_copy_enable=0,
+        )
+        uint32_array = ctypes.c_uint32 * len(ranks)
+        c_rank_ids = uint32_array(*ranks)
+        subcomm_rank = ranks.index(self.rank)
+        ranks_size = len(ranks)
+        subcomm_id = self.subcomm_id
+
+        subcomm = self.hccl.hcclCreateSubCommConfig(
+            self.comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config
+        )
+        self.subcomm_id += 1
+        return subcomm
+
+
+class DistributedHccl(Distributed):
+    def __init__(self):
+        self.pg: StatelessProcessGroup = None
+        self.pyhccl: PyHcclCommunicatorEx = None
+        self.sub_groups: dict[int, CommGroup] = {}
+        self.comm: hcclComm_t = None
+
+        self.host: str = None
+        self.port: int = None
+        self.rank: int = None
+        self.world_size: int = None
+        self.device: torch.device = None
+
+        self.initialized: bool = False
+
+    @contextmanager
+    def _use_group(self, group: CommGroup | None, src: int | None = None):
+        active_src = src
+        if group:
+            assert group.handle in self.sub_groups, "invalid sub_group"
+            newcomm = ctypes.c_void_p(group.handle)
+            self.pyhccl.comm = newcomm
+
+            if src is not None:
+                assert src in group.ranks, "src rank not in group"
+                # convert src rank id in default world to newcomm
+                active_src = group.ranks.index(src)
+                self.pyhccl.rank = group.ranks.index(self.rank)
+
+        try:
+            yield active_src
+        finally:
+            if group:
+                self.pyhccl.comm = self.comm
+                if src is not None:
+                    self.pyhccl.rank = self.rank
+
+    def init_process_group(
+        self,
+        rank: int,
+        world_size: int,
+        store: torch.distributed.TCPStore,
+        **kwargs,
+    ):
+        assert not self.initialized, "already initialized"
+
+        self.rank = rank
+        self.world_size = world_size
+        self.device = torch.device("npu", torch.npu.current_device())
+
+        self.pg = StatelessProcessGroup(rank=rank, world_size=world_size, store=store, socket=None)
+        self.pyhccl = PyHcclCommunicatorEx(group=self.pg, device=self.device)
+        self.comm = self.pyhccl.comm
+        self.initialized = True
+
+    def destroy_process_group(
+        self,
+        group: CommGroup | None = None,
+    ):
+        assert self.initialized, "not initialized"
+
+        if group and group.handle in self.sub_groups:
+            subcomm = ctypes.c_void_p(group.handle)
+            self.pyhccl.destroy_comm(subcomm)
+            del self.sub_groups[group.handle]
+            return
+
+        self.pyhccl.destroy_comm()
+        self.pyhccl = None
+        self.pg = None
+        self.initialized = False
+
+    def is_initialized(self) -> bool:
+        return self.initialized
+
+    def all_gather_object(self, object_list: list[Any], obj: Any, group: CommGroup | None = None):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            _common_all_gather_object(self.pyhccl, self.device, self.world_size, object_list, obj)
+            current_stream().synchronize()
+
+    def all_reduce(
+        self,
+        tensor: torch.Tensor,
+        op: ReduceOp.RedOpType = ReduceOp.SUM,
+        group: CommGroup | None = None,
+        **kwargs,
+    ):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            out_tensor = self.pyhccl.all_reduce(tensor, op)
+            current_stream().synchronize()
+            tensor.copy_(out_tensor)
+
+    def broadcast(
+        self, tensor: torch.Tensor, src: int | None = None, group: CommGroup | None = None, **kwargs
+    ):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group, src) as local_rank:
+            self.pyhccl.broadcast(tensor, local_rank)
+            current_stream().synchronize()
+
+    def barrier(self, group: CommGroup | None = None, **kwargs):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            data = torch.zeros(1, device=self.device)
+            self.pyhccl.all_reduce(data)
+            current_stream().synchronize()
+
+    def new_group(self, ranks: list[int], **kwargs) -> CommGroup | None:
+        assert self.initialized, "not initialized"
+
+        # ranks is None or []
+        if not ranks:
+            ranks = list(range(self.world_size))
+        else:
+            ranks.sort()
+
+        group: CommGroup = None
+        if self.rank not in ranks:
+            return group
+
+        subcomm = self.pyhccl.create_subcomm(ranks)
+        if subcomm:
+            group = CommGroup(subcomm.value, ranks)
+            self.sub_groups[subcomm.value] = group
+        return group
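A hedged end-to-end sketch for the HCCL path, assuming an Ascend NPU environment with vllm-ascend installed (the host, port, and rank values below are placeholders): select the backend, build a sub-group via new_group() (which calls HcclCreateSubCommConfig under the hood), and broadcast inside it.

import torch
import torch.distributed as torch_dist

from checkpoint_engine import distributed as dist

rank, world_size = 0, 4  # placeholders; supplied by the launcher in practice
store = torch_dist.TCPStore("127.0.0.1", 29500, world_size, is_master=(rank == 0))

dist.use_backend("vllm_hccl")                 # installs DistributedHccl as the backend
dist.init_process_group(rank, world_size, store)

group = dist.new_group([0, 1])                # CommGroup, or None if this rank is not a member
if group is not None:
    t = torch.zeros(1, device="npu")
    dist.broadcast(t, src=0, group=group)     # src is a global rank, remapped inside _use_group
    dist.destroy_process_group(group)         # frees only the HCCL sub-communicator

dist.barrier()
dist.destroy_process_group()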