checkpoint-engine 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
checkpoint_engine/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.3.3'
32
- __version_tuple__ = version_tuple = (0, 3, 3)
31
+ __version__ = version = '0.4.0'
32
+ __version_tuple__ = version_tuple = (0, 4, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
checkpoint_engine/distributed/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ from .base import (
2
+ Distributed,
3
+ DistributedProcessGroup,
4
+ all_gather_object,
5
+ all_reduce,
6
+ barrier,
7
+ broadcast,
8
+ destroy_process_group,
9
+ init_process_group,
10
+ is_initialized,
11
+ new_group,
12
+ use_backend,
13
+ )
14
+
15
+
16
+ __all__ = [
17
+ "Distributed",
18
+ "DistributedProcessGroup",
19
+ "all_gather_object",
20
+ "all_reduce",
21
+ "barrier",
22
+ "broadcast",
23
+ "destroy_process_group",
24
+ "init_process_group",
25
+ "is_initialized",
26
+ "new_group",
27
+ "use_backend",
28
+ ]
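The new checkpoint_engine.distributed package re-exports a backend-agnostic facade over the collectives implemented in base.py below. A minimal usage sketch with the default torch.distributed backend (the store endpoint, the RANK/WORLD_SIZE environment variables and the CUDA/NCCL device are illustrative assumptions, not package requirements):

import os

import torch
import torch.distributed as torch_dist

import checkpoint_engine.distributed as dist

rank = int(os.environ["RANK"])          # assumed to be provided by the launcher
world_size = int(os.environ["WORLD_SIZE"])
store = torch_dist.TCPStore("127.0.0.1", 29500, world_size, is_master=(rank == 0))

# use_backend("vllm_nccl") / use_backend("vllm_hccl") would swap in a vLLM-based backend;
# by default the module delegates everything to torch.distributed.
dist.init_process_group(rank, world_size, store, backend="nccl")
t = torch.ones(1, device="cuda")
dist.all_reduce(t)                      # op defaults to ReduceOp.SUM
dist.barrier()
dist.destroy_process_group()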
checkpoint_engine/distributed/base.py ADDED
@@ -0,0 +1,288 @@
1
+ import importlib
2
+ import io
3
+ import pickle
4
+ from abc import ABC, abstractmethod
5
+ from datetime import timedelta
6
+ from typing import Any, Protocol
7
+
8
+ import torch
9
+ import torch.distributed as torch_dist
10
+
11
+
12
+ class CommunicatorProtocol(Protocol):
13
+ def all_gather(self, *args: Any, **kwargs: Any) -> torch.Tensor: ...
14
+
15
+
16
+ class CommGroup:
17
+ def __init__(self, comm_handle: int, ranks: list[int]):
18
+ self._comm = comm_handle
19
+ self._ranks = ranks
20
+
21
+ @property
22
+ def handle(self) -> int:
23
+ return self._comm
24
+
25
+ @property
26
+ def ranks(self) -> list[int]:
27
+ return self._ranks
28
+
29
+
30
+ DistributedProcessGroup = torch_dist.ProcessGroup | CommGroup
31
+
32
+
33
+ class Distributed(ABC):
34
+ @abstractmethod
35
+ def init_process_group(
36
+ self,
37
+ rank: int,
38
+ world_size: int,
39
+ store: torch_dist.TCPStore,
40
+ **kwargs,
41
+ ):
42
+ raise NotImplementedError
43
+
44
+ @abstractmethod
45
+ def destroy_process_group(
46
+ self,
47
+ group: DistributedProcessGroup | None = None,
48
+ ):
49
+ raise NotImplementedError
50
+
51
+ @abstractmethod
52
+ def is_initialized(self) -> bool:
53
+ raise NotImplementedError
54
+
55
+ @abstractmethod
56
+ def all_gather_object(
57
+ self,
58
+ object_list: list[Any],
59
+ obj: Any,
60
+ group: DistributedProcessGroup | None = None,
61
+ ):
62
+ raise NotImplementedError
63
+
64
+ @abstractmethod
65
+ def all_reduce(
66
+ self,
67
+ tensor: torch.Tensor,
68
+ op: torch_dist.ReduceOp.RedOpType,
69
+ group: DistributedProcessGroup | None = None,
70
+ **kwargs,
71
+ ):
72
+ raise NotImplementedError
73
+
74
+ @abstractmethod
75
+ def broadcast(
76
+ self,
77
+ tensor: torch.Tensor,
78
+ src: int,
79
+ group: DistributedProcessGroup | None = None,
80
+ **kwargs,
81
+ ):
82
+ raise NotImplementedError
83
+
84
+ @abstractmethod
85
+ def barrier(
86
+ self,
87
+ group: DistributedProcessGroup | None = None,
88
+ **kwargs,
89
+ ):
90
+ raise NotImplementedError
91
+
92
+ @abstractmethod
93
+ def new_group(
94
+ self,
95
+ ranks: list[int],
96
+ **kwargs,
97
+ ):
98
+ raise NotImplementedError
99
+
100
+
101
+ class TorchBackend(Distributed):
102
+ def init_process_group(
103
+ self,
104
+ rank: int,
105
+ world_size: int,
106
+ store: torch_dist.TCPStore,
107
+ **kwargs,
108
+ ):
109
+ backend = kwargs.get("backend", "nccl")
110
+ timeout = kwargs.get("timeout", timedelta(minutes=10))
111
+
112
+ torch_dist.init_process_group(
113
+ backend=backend,
114
+ world_size=world_size,
115
+ rank=rank,
116
+ timeout=timeout,
117
+ store=store,
118
+ )
119
+
120
+ def destroy_process_group(self, group: DistributedProcessGroup | None = None):
121
+ torch_dist.destroy_process_group(group)
122
+
123
+ def is_initialized(self) -> bool:
124
+ return torch_dist.is_initialized()
125
+
126
+ def all_gather_object(
127
+ self, object_list: list[Any], obj: Any, group: DistributedProcessGroup | None = None
128
+ ):
129
+ torch_dist.all_gather_object(object_list, obj, group)
130
+
131
+ def all_reduce(
132
+ self,
133
+ tensor: torch.Tensor,
134
+ op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
135
+ group: DistributedProcessGroup | None = None,
136
+ **kwargs,
137
+ ):
138
+ torch_dist.all_reduce(tensor, op, group, **kwargs)
139
+
140
+ def broadcast(
141
+ self,
142
+ tensor: torch.Tensor,
143
+ src: int = 0,
144
+ group: DistributedProcessGroup | None = None,
145
+ **kwargs,
146
+ ):
147
+ torch_dist.broadcast(tensor, src, group, **kwargs)
148
+
149
+ def barrier(self, group: DistributedProcessGroup | None = None, **kwargs):
150
+ torch_dist.barrier(group, **kwargs)
151
+
152
+ def new_group(self, ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
153
+ return torch_dist.new_group(ranks, **kwargs)
154
+
155
+
156
+ # the active backend instance; defaults to the torch.distributed-based TorchBackend
157
+ _BACKEND_INSTANCE: Distributed = TorchBackend()
158
+
159
+ _pickler = pickle.Pickler
160
+ _unpickler = pickle.Unpickler
161
+
162
+
163
+ def _object_to_tensor(obj: Any, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
164
+ f = io.BytesIO()
165
+ _pickler(f).dump(obj)
166
+ byte_storage = torch.ByteStorage._from_buffer(f.getvalue())
167
+ byte_tensor = torch.ByteTensor(byte_storage).to(device)
168
+ local_size = torch.LongTensor([byte_tensor.numel()]).to(device)
169
+ return byte_tensor, local_size
170
+
171
+
172
+ def _tensor_to_object(tensor: torch.Tensor, tensor_size: int) -> Any:
173
+ tensor = tensor.cpu()
174
+ buf = tensor.numpy().tobytes()[:tensor_size]
175
+ return _unpickler(io.BytesIO(buf)).load()
176
+
177
+
178
+ def _flatten_for_scatter_gather(
179
+ tensor_list: list[torch.Tensor], copy: bool = False
180
+ ) -> torch.Tensor:
181
+ if not tensor_list:
182
+ raise RuntimeError("Received an empty list.")
183
+ t = tensor_list[0]
184
+ buffer_shape = [len(tensor_list)] + list(t.shape)
185
+
186
+ buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device)
187
+ if copy:
188
+ for i, tensor in enumerate(tensor_list):
189
+ buffer[i].copy_(tensor)
190
+ return buffer
191
+
192
+
193
+ def _common_all_gather_object(
194
+ comm: CommunicatorProtocol,
195
+ device: torch.device,
196
+ world_size: int,
197
+ object_list: list[Any],
198
+ object: Any,
199
+ ):
200
+ input_tensor, local_size = _object_to_tensor(object, device)
201
+ object_sizes_tensor = torch.empty(world_size, dtype=torch.long, device=device)
202
+ comm.all_gather(object_sizes_tensor, local_size)
203
+ object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(world_size)]
204
+ max_object_size = int(max(object_size_list).item())
205
+ input_tensor.resize_(max_object_size)
206
+ coalesced_output_tensor = torch.empty(
207
+ max_object_size * world_size, dtype=torch.uint8, device=device
208
+ )
209
+
210
+ comm.all_gather(coalesced_output_tensor, input_tensor)
211
+ output_tensors = [
212
+ coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
213
+ for i in range(world_size)
214
+ ]
215
+ for i, tensor in enumerate(output_tensors):
216
+ tensor = tensor.type(torch.uint8)
217
+ tensor_size = object_size_list[i]
218
+ object_list[i] = _tensor_to_object(tensor, tensor_size)
219
+
220
+
221
+ def use_backend(backend: str | None):
222
+ global _BACKEND_INSTANCE
223
+
224
+ if not backend:
225
+ return
226
+
227
+ mapping = {
228
+ "vllm_nccl": ".vllm_nccl.DistributedNccl",
229
+ "vllm_hccl": ".vllm_hccl.DistributedHccl",
230
+ }
231
+ if backend not in mapping:
232
+ raise ValueError(f"Unsupported custom backend: {backend}")
233
+
234
+ module_path, class_name = mapping[backend].rsplit(".", 1)
235
+ module = importlib.import_module(module_path, "checkpoint_engine.distributed")
236
+ backend_class = getattr(module, class_name)
237
+ _BACKEND_INSTANCE = backend_class()
238
+
239
+
240
+ def init_process_group(
241
+ rank: int,
242
+ world_size: int,
243
+ store: torch_dist.TCPStore,
244
+ **kwargs,
245
+ ):
246
+ _BACKEND_INSTANCE.init_process_group(rank, world_size, store, **kwargs)
247
+
248
+
249
+ def destroy_process_group(group: DistributedProcessGroup | None = None):
250
+ _BACKEND_INSTANCE.destroy_process_group(group)
251
+
252
+
253
+ def is_initialized() -> bool:
254
+ return _BACKEND_INSTANCE.is_initialized()
255
+
256
+
257
+ def all_gather_object(
258
+ object_list: list[Any],
259
+ obj: Any,
260
+ group: DistributedProcessGroup | None = None,
261
+ ):
262
+ _BACKEND_INSTANCE.all_gather_object(object_list, obj, group)
263
+
264
+
265
+ def all_reduce(
266
+ tensor: torch.Tensor,
267
+ op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
268
+ group: DistributedProcessGroup | None = None,
269
+ **kwargs,
270
+ ):
271
+ _BACKEND_INSTANCE.all_reduce(tensor, op, group, **kwargs)
272
+
273
+
274
+ def broadcast(
275
+ tensor: torch.Tensor,
276
+ src: int = 0,
277
+ group: DistributedProcessGroup | None = None,
278
+ **kwargs,
279
+ ):
280
+ _BACKEND_INSTANCE.broadcast(tensor, src, group, **kwargs)
281
+
282
+
283
+ def barrier(group: DistributedProcessGroup | None = None, **kwargs):
284
+ _BACKEND_INSTANCE.barrier(group, **kwargs)
285
+
286
+
287
+ def new_group(ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
288
+ return _BACKEND_INSTANCE.new_group(ranks, **kwargs)
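For contrast with the default TorchBackend path, a hedged sketch of selecting the vLLM NCCL backend through use_backend; it assumes vLLM is installed and a CUDA device is active, and the endpoint and ranks are made up for illustration:

import torch
import torch.distributed as torch_dist

import checkpoint_engine.distributed as dist

dist.use_backend("vllm_nccl")       # lazily imports .vllm_nccl and swaps _BACKEND_INSTANCE

store = torch_dist.TCPStore("10.0.0.1", 29501, 8, is_master=False)  # illustrative endpoint
dist.init_process_group(rank=3, world_size=8, store=store)

group = dist.new_group([0, 1, 2, 3])     # a CommGroup wrapping an ncclCommSplit handle
t = torch.zeros(1, device="cuda")
dist.broadcast(t, src=0, group=group)    # src is a global rank; it is remapped inside the group
dist.barrier(group=group)
dist.destroy_process_group(group)        # tear down only the sub-communicator
dist.destroy_process_group()             # then the whole backend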
checkpoint_engine/distributed/vllm_hccl.py ADDED
@@ -0,0 +1,323 @@
1
+ import ctypes
2
+ from contextlib import contextmanager
3
+ from typing import Any, ClassVar
4
+
5
+ import torch
6
+ from torch.distributed import ReduceOp
7
+ from vllm.distributed.utils import StatelessProcessGroup
8
+ from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator
9
+ from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
10
+ Function,
11
+ HCCLLibrary,
12
+ aclrtStream_t,
13
+ buffer_type,
14
+ hcclComm_t,
15
+ hcclDataType_t,
16
+ hcclDataTypeEnum,
17
+ hcclResult_t,
18
+ )
19
+ from vllm_ascend.utils import current_stream
20
+
21
+ from checkpoint_engine.distributed.base import CommGroup, Distributed, _common_all_gather_object
22
+
23
+
24
+ class HcclCommConfig(ctypes.Structure):
25
+ _fields_: ClassVar[list[tuple[str, Any]]] = [
26
+ ("size", ctypes.c_size_t),
27
+ ("magic_word", ctypes.c_uint32),
28
+ ("version", ctypes.c_uint32),
29
+ ("reserved", ctypes.c_uint64),
30
+ ("hccl_buffer_size", ctypes.c_uint32),
31
+ ("hccl_deterministic", ctypes.c_uint32),
32
+ ("hccl_comm_name", ctypes.c_char * 128),
33
+ ("hccl_udi", ctypes.c_char * 128),
34
+ ("hccl_op_expansion_mode", ctypes.c_uint32),
35
+ ("hccl_rdma_traffic_class", ctypes.c_uint32),
36
+ ("hccl_rdma_service_level", ctypes.c_uint32),
37
+ ("hccl_world_rank_id", ctypes.c_uint32),
38
+ ("hccl_job_id", ctypes.c_uint64),
39
+ ("comm_engine", ctypes.c_int32),
40
+ ("thread_num", ctypes.c_uint32),
41
+ ("notify_num_per_thread", ctypes.c_uint32),
42
+ ("acl_graph_zero_copy_enable", ctypes.c_uint8),
43
+ ]
44
+
45
+
46
+ orig_exported_functions = HCCLLibrary.exported_functions
47
+ extended_functions = [
48
+ # HcclResult HcclAllGather(
49
+ # void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType,
50
+ # HcclComm comm, alcrtStream stream
51
+ # )
52
+ Function(
53
+ "HcclAllGather",
54
+ hcclResult_t,
55
+ [
56
+ buffer_type,
57
+ buffer_type,
58
+ ctypes.c_uint64,
59
+ hcclDataType_t,
60
+ hcclComm_t,
61
+ aclrtStream_t,
62
+ ],
63
+ ),
64
+ # HcclResult HcclCreateSubCommConfig(
65
+ # HcclComm *comm, uint32_t rankNum, uint32_t *rankIds, uint64_t subCommId,
66
+ # uint32_t subCommRankId, HcclCommConfig *config, HcclComm *subComm
67
+ # )
68
+ Function(
69
+ "HcclCreateSubCommConfig",
70
+ hcclResult_t,
71
+ [
72
+ ctypes.POINTER(hcclComm_t),
73
+ ctypes.c_uint32,
74
+ ctypes.POINTER(ctypes.c_uint32),
75
+ ctypes.c_uint64,
76
+ ctypes.c_uint32,
77
+ ctypes.POINTER(HcclCommConfig),
78
+ ctypes.POINTER(hcclComm_t),
79
+ ],
80
+ ),
81
+ ]
82
+
83
+
84
+ def hccl_all_gather(
85
+ self, # noqa: ANN001
86
+ send_buf: buffer_type,
87
+ recv_buf: buffer_type,
88
+ count: ctypes.c_uint64,
89
+ data_type: hcclDataType_t,
90
+ comm: hcclComm_t,
91
+ stream: aclrtStream_t,
92
+ ):
93
+ self.HCCL_CHECK(
94
+ self._funcs["HcclAllGather"](send_buf, recv_buf, count, data_type, comm, stream)
95
+ )
96
+
97
+
98
+ def hccl_create_subcomm_config(
99
+ self, # noqa: ANN001
100
+ comm: hcclComm_t,
101
+ ranks_size: ctypes.c_uint32,
102
+ c_rank_ids: ctypes.POINTER(ctypes.c_uint32),
103
+ subcomm_id: ctypes.c_uint64,
104
+ subcomm_rank: ctypes.c_uint32,
105
+ comm_config: HcclCommConfig,
106
+ ) -> hcclComm_t:
107
+ subcomm = hcclComm_t()
108
+ self.HCCL_CHECK(
109
+ self._funcs["HcclCreateSubCommConfig"](
110
+ ctypes.byref(comm),
111
+ ranks_size,
112
+ c_rank_ids,
113
+ subcomm_id,
114
+ subcomm_rank,
115
+ ctypes.byref(comm_config),
116
+ ctypes.byref(subcomm),
117
+ )
118
+ )
119
+ return subcomm
120
+
121
+
122
+ # extend HCCLLibrary
123
+ HCCLLibrary.exported_functions = orig_exported_functions + extended_functions
124
+ HCCLLibrary.hcclAllGather = hccl_all_gather
125
+ HCCLLibrary.hcclCreateSubCommConfig = hccl_create_subcomm_config
126
+
127
+
128
+ class PyHcclCommunicatorEx(PyHcclCommunicator):
129
+ def __init__(self, group: StatelessProcessGroup, device: torch.device):
130
+ super().__init__(group, device)
131
+ self.subcomm_id = 1
132
+
133
+ def destroy_comm(self, comm: hcclComm_t = None):
134
+ if comm:
135
+ self.hccl.hcclCommDestroy(comm)
136
+ else:
137
+ self.hccl.hcclCommDestroy(self.comm)
138
+
139
+ def all_gather(
140
+ self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream: torch.npu.Stream = None
141
+ ) -> torch.Tensor:
142
+ if self.disabled:
143
+ return
144
+ assert in_tensor.device == self.device, (
145
+ f"this hccl communicator is created to work on {self.device}, "
146
+ f"but the input tensor is on {in_tensor.device}"
147
+ )
148
+ if stream is None:
149
+ stream = current_stream()
150
+ self.hccl.hcclAllGather(
151
+ buffer_type(in_tensor.data_ptr()),
152
+ buffer_type(out_tensor.data_ptr()),
153
+ in_tensor.numel(),
154
+ hcclDataTypeEnum.from_torch(in_tensor.dtype),
155
+ self.comm, # todo
156
+ aclrtStream_t(stream.npu_stream),
157
+ )
158
+ return out_tensor
159
+
160
+ def create_subcomm(self, ranks: list[int]) -> hcclComm_t:
161
+ comm_config = HcclCommConfig(
162
+ size=312,
163
+ magic_word=0xF0F0F0F0,
164
+ version=6,
165
+ reserved=0,
166
+ hccl_buffer_size=0xFFFFFFFF,
167
+ hccl_deterministic=0xFFFFFFFF,
168
+ hccl_comm_name=b"\0",
169
+ hccl_udi=b"\0",
170
+ hccl_op_expansion_mode=0,
171
+ hccl_rdma_traffic_class=0xFFFFFFFF,
172
+ hccl_rdma_service_level=0xFFFFFFFF,
173
+ hccl_world_rank_id=0,
174
+ hccl_job_id=0,
175
+ comm_engine=-1,
176
+ thread_num=0xFFFFFFFF,
177
+ notify_num_per_thread=0xFFFFFFFF,
178
+ acl_graph_zero_copy_enable=0,
179
+ )
180
+ uint32_array = ctypes.c_uint32 * len(ranks)
181
+ c_rank_ids = uint32_array(*ranks)
182
+ subcomm_rank = ranks.index(self.rank)
183
+ ranks_size = len(ranks)
184
+ subcomm_id = self.subcomm_id
185
+
186
+ subcomm = self.hccl.hcclCreateSubCommConfig(
187
+ self.comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config
188
+ )
189
+ self.subcomm_id += 1
190
+ return subcomm
191
+
192
+
193
+ class DistributedHccl(Distributed):
194
+ def __init__(self):
195
+ self.pg: StatelessProcessGroup = None
196
+ self.pyhccl: PyHcclCommunicatorEx = None
197
+ self.sub_groups: dict[int, CommGroup] = {}
198
+ self.comm: hcclComm_t = None
199
+
200
+ self.host: str = None
201
+ self.port: int = None
202
+ self.rank: int = None
203
+ self.world_size: int = None
204
+ self.device: torch.device = None
205
+
206
+ self.initialized: bool = False
207
+
208
+ @contextmanager
209
+ def _use_group(self, group: CommGroup | None, src: int | None = None):
210
+ active_src = src
211
+ if group:
212
+ assert group.handle in self.sub_groups, "invalid sub_group"
213
+ newcomm = ctypes.c_void_p(group.handle)
214
+ self.pyhccl.comm = newcomm
215
+
216
+ if src is not None:
217
+ assert src in group.ranks, "src rank not in group"
218
+ # convert src rank id in default world to newcomm
219
+ active_src = group.ranks.index(src)
220
+ self.pyhccl.rank = group.ranks.index(self.rank)
221
+
222
+ try:
223
+ yield active_src
224
+ finally:
225
+ if group:
226
+ self.pyhccl.comm = self.comm
227
+ if src is not None:
228
+ self.pyhccl.rank = self.rank
229
+
230
+ def init_process_group(
231
+ self,
232
+ rank: int,
233
+ world_size: int,
234
+ store: torch.distributed.TCPStore,
235
+ **kwargs,
236
+ ):
237
+ assert not self.initialized, "already initialized"
238
+
239
+ self.rank = rank
240
+ self.world_size = world_size
241
+ self.device = torch.device("npu", torch.npu.current_device())
242
+
243
+ self.pg = StatelessProcessGroup(rank=rank, world_size=world_size, store=store, socket=None)
244
+ self.pyhccl = PyHcclCommunicatorEx(group=self.pg, device=self.device)
245
+ self.comm = self.pyhccl.comm
246
+ self.initialized = True
247
+
248
+ def destroy_process_group(
249
+ self,
250
+ group: CommGroup | None = None,
251
+ ):
252
+ assert self.initialized, "not initialized"
253
+
254
+ if group and group.handle in self.sub_groups:
255
+ subcomm = ctypes.c_void_p(group.handle)
256
+ self.pyhccl.destroy_comm(subcomm)
257
+ del self.sub_groups[group.handle]
258
+ return
259
+
260
+ self.pyhccl.destroy_comm()
261
+ self.pyhccl = None
262
+ self.pg = None
263
+ self.initialized = False
264
+
265
+ def is_initialized(self) -> bool:
266
+ return self.initialized
267
+
268
+ def all_gather_object(self, object_list: list[Any], obj: Any, group: CommGroup | None = None):
269
+ assert self.initialized, "not initialized"
270
+
271
+ with self._use_group(group):
272
+ _common_all_gather_object(self.pyhccl, self.device, self.world_size, object_list, obj)
273
+ current_stream().synchronize()
274
+
275
+ def all_reduce(
276
+ self,
277
+ tensor: torch.Tensor,
278
+ op: ReduceOp.RedOpType = ReduceOp.SUM,
279
+ group: CommGroup | None = None,
280
+ **kwargs,
281
+ ):
282
+ assert self.initialized, "not initialized"
283
+
284
+ with self._use_group(group):
285
+ out_tensor = self.pyhccl.all_reduce(tensor, op)
286
+ current_stream().synchronize()
287
+ tensor.copy_(out_tensor)
288
+
289
+ def broadcast(
290
+ self, tensor: torch.Tensor, src: int | None = None, group: CommGroup | None = None, **kwargs
291
+ ):
292
+ assert self.initialized, "not initialized"
293
+
294
+ with self._use_group(group, src) as local_rank:
295
+ self.pyhccl.broadcast(tensor, local_rank)
296
+ current_stream().synchronize()
297
+
298
+ def barrier(self, group: CommGroup | None = None, **kwargs):
299
+ assert self.initialized, "not initialized"
300
+
301
+ with self._use_group(group):
302
+ data = torch.zeros(1, device=self.device)
303
+ self.pyhccl.all_reduce(data)
304
+ current_stream().synchronize()
305
+
306
+ def new_group(self, ranks: list[int], **kwargs) -> CommGroup | None:
307
+ assert self.initialized, "not initialized"
308
+
309
+ # ranks is None or []
310
+ if not ranks:
311
+ ranks = list(range(self.world_size))
312
+ else:
313
+ ranks.sort()
314
+
315
+ group: CommGroup = None
316
+ if self.rank not in ranks:
317
+ return group
318
+
319
+ subcomm = self.pyhccl.create_subcomm(ranks)
320
+ if subcomm:
321
+ group = CommGroup(subcomm.value, ranks)
322
+ self.sub_groups[subcomm.value] = group
323
+ return group
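One detail worth spelling out in _use_group above: src is always passed as a global rank id, while the HCCL sub-communicator numbers its members 0..len(ranks)-1, so the context manager remaps it by position. A tiny standalone illustration of that mapping (plain Python, not package code):

ranks = [2, 5, 7]                     # global ranks forming the sub-communicator (sorted)
global_src = 5                        # the caller's notion of the source rank
local_src = ranks.index(global_src)   # the rank id actually used inside the subcomm
assert local_src == 1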
checkpoint_engine/distributed/vllm_nccl.py ADDED
@@ -0,0 +1,223 @@
1
+ import ctypes
2
+ from contextlib import contextmanager
3
+ from typing import Any, ClassVar
4
+
5
+ import torch
6
+ from torch.distributed import ReduceOp
7
+ from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
8
+ from vllm.distributed.device_communicators.pynccl_wrapper import (
9
+ Function,
10
+ NCCLLibrary,
11
+ ncclComm_t,
12
+ ncclResult_t,
13
+ )
14
+ from vllm.distributed.utils import StatelessProcessGroup
15
+ from vllm.utils import current_stream
16
+
17
+ from checkpoint_engine.distributed.base import CommGroup, Distributed, _common_all_gather_object
18
+
19
+
20
+ class NcclConfigT(ctypes.Structure):
21
+ _fields_: ClassVar[list[tuple[str, Any]]] = [
22
+ ("size", ctypes.c_size_t),
23
+ ("magic", ctypes.c_uint),
24
+ ("version", ctypes.c_uint),
25
+ ("blocking", ctypes.c_int),
26
+ ("cgaClusterSize", ctypes.c_int),
27
+ ("minCTAs", ctypes.c_int),
28
+ ("maxCTAs", ctypes.c_int),
29
+ ("netName", ctypes.c_char_p),
30
+ ("splitShare", ctypes.c_int),
31
+ ("trafficClass", ctypes.c_int),
32
+ ("commName", ctypes.c_char_p),
33
+ ("collnetEnable", ctypes.c_int),
34
+ ("CTAPolicy", ctypes.c_int),
35
+ ("shrinkShare", ctypes.c_int),
36
+ ("nvlsCTAs", ctypes.c_int),
37
+ ("nChannelsPerNetPeer", ctypes.c_int),
38
+ ("nvlinkCentricSched", ctypes.c_int),
39
+ ("graphUsageMode", ctypes.c_int),
40
+ ("numRmaCtx", ctypes.c_int),
41
+ ]
42
+
43
+
44
+ nccl_orig_exported_functions = NCCLLibrary.exported_functions
45
+ nccl_extended_functions = [
46
+ # ncclResult_t ncclCommSplit(
47
+ # ncclComm_t comm, int color, int key, ncclComm_t *newcomm, NcclConfigT *config
48
+ # )
49
+ Function(
50
+ "ncclCommSplit",
51
+ ncclResult_t,
52
+ [
53
+ ncclComm_t,
54
+ ctypes.c_int,
55
+ ctypes.c_int,
56
+ ctypes.POINTER(ncclComm_t),
57
+ ctypes.POINTER(NcclConfigT),
58
+ ],
59
+ ),
60
+ ]
61
+
62
+
63
+ def nccl_comm_split(
64
+ self, # noqa: ANN001
65
+ comm: ncclComm_t,
66
+ color: int,
67
+ key: int,
68
+ ) -> ncclComm_t:
69
+ newcomm = ncclComm_t()
70
+
71
+ self.NCCL_CHECK(self._funcs["ncclCommSplit"](comm, color, key, ctypes.byref(newcomm), None))
72
+ return newcomm
73
+
74
+
75
+ # extend NCCLLibrary
76
+ NCCLLibrary.exported_functions = nccl_orig_exported_functions + nccl_extended_functions
77
+ NCCLLibrary.ncclCommSplit = nccl_comm_split
78
+
79
+
80
+ class PyNcclCommunicatorEx(PyNcclCommunicator):
81
+ def destroy_comm(self, comm: ncclComm_t = None):
82
+ if comm:
83
+ self.nccl.ncclCommDestroy(comm)
84
+ else:
85
+ self.nccl.ncclCommDestroy(self.comm)
86
+
87
+ def create_newcomm(self, ranks: list[int]) -> ncclComm_t:
88
+ if self.rank in ranks:
89
+ color = 0
90
+ else:
91
+ color = -1 # NCCL_SPLIT_NOCOLOR
92
+ newcomm = self.nccl.ncclCommSplit(self.comm, color, self.rank)
93
+ return newcomm
94
+
95
+
96
+ class DistributedNccl(Distributed):
97
+ def __init__(self):
98
+ self.pg: StatelessProcessGroup = None
99
+ self.pynccl: PyNcclCommunicatorEx = None
100
+ self.sub_groups: dict[int, CommGroup] = {}
101
+ self.comm: ncclComm_t = None
102
+
103
+ self.host: str = None
104
+ self.port: int = None
105
+ self.rank: int = None
106
+ self.world_size: int = None
107
+ self.device: torch.device = None
108
+
109
+ self.initialized: bool = False
110
+
111
+ @contextmanager
112
+ def _use_group(self, group: CommGroup | None, src: int | None = None):
113
+ active_src = src
114
+ if group:
115
+ assert group.handle in self.sub_groups, "invalid sub_group"
116
+ newcomm = ctypes.c_void_p(group.handle)
117
+ self.pynccl.comm = newcomm
118
+
119
+ if src is not None:
120
+ assert src in group.ranks, "src rank not in group"
121
+ # convert src rank id in default world to newcomm
122
+ active_src = group.ranks.index(src)
123
+ self.pynccl.rank = group.ranks.index(self.rank)
124
+
125
+ try:
126
+ yield active_src
127
+ finally:
128
+ if group:
129
+ self.pynccl.comm = self.comm
130
+ if src is not None:
131
+ self.pynccl.rank = self.rank
132
+
133
+ def init_process_group(
134
+ self,
135
+ rank: int,
136
+ world_size: int,
137
+ store: torch.distributed.TCPStore,
138
+ **kwargs,
139
+ ):
140
+ assert not self.initialized, "already initialized"
141
+
142
+ self.rank = rank
143
+ self.world_size = world_size
144
+ self.device = torch.device("cuda", torch.cuda.current_device())
145
+
146
+ self.pg = StatelessProcessGroup(rank=rank, world_size=world_size, store=store, socket=None)
147
+ self.pynccl = PyNcclCommunicatorEx(group=self.pg, device=self.device)
148
+ self.comm = self.pynccl.comm
149
+ self.initialized = True
150
+
151
+ def destroy_process_group(
152
+ self,
153
+ group: CommGroup | None = None,
154
+ ):
155
+ assert self.initialized, "not initialized"
156
+
157
+ if group and group.handle in self.sub_groups:
158
+ newcomm = ctypes.c_void_p(group.handle)
159
+ self.pynccl.destroy_comm(newcomm)
160
+ del self.sub_groups[group.handle]
161
+ return
162
+
163
+ self.pynccl.destroy_comm()
164
+ self.pynccl = None
165
+ self.pg = None
166
+ self.initialized = False
167
+
168
+ def is_initialized(self) -> bool:
169
+ return self.initialized
170
+
171
+ def all_gather_object(self, object_list: list[Any], obj: Any, group: CommGroup | None = None):
172
+ assert self.initialized, "not initialized"
173
+
174
+ with self._use_group(group):
175
+ _common_all_gather_object(self.pynccl, self.device, self.world_size, object_list, obj)
176
+ current_stream().synchronize()
177
+
178
+ def all_reduce(
179
+ self,
180
+ tensor: torch.Tensor,
181
+ op: ReduceOp.RedOpType = ReduceOp.SUM,
182
+ group: CommGroup | None = None,
183
+ **kwargs,
184
+ ):
185
+ assert self.initialized, "not initialized"
186
+
187
+ with self._use_group(group):
188
+ out_tensor = self.pynccl.all_reduce(in_tensor=tensor, op=op)
189
+ current_stream().synchronize()
190
+ tensor.copy_(out_tensor)
191
+
192
+ def broadcast(
193
+ self, tensor: torch.Tensor, src: int | None = None, group: CommGroup | None = None, **kwargs
194
+ ):
195
+ assert self.initialized, "not initialized"
196
+
197
+ with self._use_group(group, src) as local_src:
198
+ self.pynccl.broadcast(tensor, local_src)
199
+ current_stream().synchronize()
200
+
201
+ def barrier(self, group: CommGroup | None = None, **kwargs):
202
+ assert self.initialized, "not initialized"
203
+
204
+ with self._use_group(group):
205
+ data = torch.zeros(1, device=self.device)
206
+ self.pynccl.all_reduce(data)
207
+ current_stream().synchronize()
208
+
209
+ def new_group(self, ranks: list[int], **kwargs) -> CommGroup | None:
210
+ assert self.initialized, "not initialized"
211
+
212
+ # ranks is None or []
213
+ if not ranks:
214
+ ranks = list(range(self.world_size))
215
+ else:
216
+ ranks.sort()
217
+
218
+ group: CommGroup = None
219
+ newcomm = self.pynccl.create_newcomm(ranks)
220
+ if newcomm:
221
+ group = CommGroup(newcomm.value, ranks)
222
+ self.sub_groups[newcomm.value] = group
223
+ return group
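create_newcomm above relies on the standard ncclCommSplit color/key convention: ranks that belong to the new group pass color 0, all other ranks pass NCCL_SPLIT_NOCOLOR (-1) and get no communicator back, and the caller's global rank is used as the key so member ordering is preserved. A standalone sketch of the color assignment (not package code):

NCCL_SPLIT_NOCOLOR = -1

def split_color(rank: int, ranks: list[int]) -> int:
    # members of the new group share one color; everyone else opts out of the split
    return 0 if rank in ranks else NCCL_SPLIT_NOCOLOR

ranks = [0, 2, 4, 6]
colors = {r: split_color(r, ranks) for r in range(8)}
# -> {0: 0, 1: -1, 2: 0, 3: -1, 4: 0, 5: -1, 6: 0, 7: -1}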
checkpoint_engine/ps.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import ctypes
2
+ import gc
2
3
  import os
3
4
  import threading
4
5
  from collections import defaultdict
@@ -7,11 +8,12 @@ from datetime import timedelta
7
8
  from typing import TYPE_CHECKING
8
9
 
9
10
  import torch
10
- import torch.distributed as dist
11
+ import torch.distributed
11
12
  import zmq
12
13
  from loguru import logger
13
14
  from torch.multiprocessing.reductions import reduce_tensor
14
15
 
16
+ import checkpoint_engine.distributed as dist
15
17
  from checkpoint_engine.data_types import (
16
18
  BucketRange,
17
19
  DataToGather,
@@ -175,6 +177,8 @@ class ParameterServer:
175
177
  auto_pg: bool = True,
176
178
  gpu_count: int | None = None,
177
179
  mem_fraction: float | None = None,
180
+ master_addr: str | None = None,
181
+ master_port: int | None = None,
178
182
  ):
179
183
  """
180
184
  Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -228,6 +232,17 @@ class ParameterServer:
228
232
  self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
229
233
  self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
230
234
 
235
+ master_addr = master_addr or os.getenv("MASTER_ADDR")
236
+ assert master_addr, "master_addr is required"
237
+ self._store = torch.distributed.TCPStore(
238
+ master_addr,
239
+ _get_master_port(master_port),
240
+ self._world_size,
241
+ timeout=timedelta(minutes=10),
242
+ is_master=self._rank == 0,
243
+ )
244
+ self._store_counter = 0
245
+
231
246
  def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
232
247
  if checkpoint_name == self._current_shared_memory_pool_user:
233
248
  assert self._memory_pool[self.shared_memory_pool_name], (
@@ -487,8 +502,6 @@ class ParameterServer:
487
502
  def init_process_group(
488
503
  self,
489
504
  *,
490
- master_addr: str | None = None,
491
- master_port: int | None = None,
492
505
  timeout: timedelta = timedelta(minutes=10),
493
506
  ):
494
507
  """
@@ -498,27 +511,18 @@ class ParameterServer:
498
511
  master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
499
512
  timeout: The timeout of the process group.
500
513
  """
501
- master_addr = master_addr or os.getenv("MASTER_ADDR")
502
- assert master_addr, "master_addr is required"
503
- store = dist.TCPStore(
504
- master_addr,
505
- _get_master_port(master_port),
506
- self._world_size,
507
- timeout=timeout,
508
- is_master=self._rank == 0,
509
- )
514
+ self._store_counter += 1
515
+ sub_store = torch.distributed.PrefixStore(f"prefix-{self._store_counter}", self._store)
510
516
  dist.init_process_group(
511
517
  backend=self.device_manager.backend,
512
518
  world_size=self._world_size,
513
519
  rank=self._rank,
514
520
  timeout=timeout,
515
- store=store,
521
+ store=sub_store,
516
522
  )
517
523
  logger.info(f"[rank{self._rank}] init process group successfully.")
518
524
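The change above means the ParameterServer now creates one long-lived TCPStore in __init__ and wraps it in a fresh PrefixStore on every process-group (re)initialization, so rendezvous keys from earlier rounds cannot collide. A hedged standalone sketch of the same pattern (single process, gloo backend, illustrative endpoint):

from datetime import timedelta

import torch.distributed as torch_dist

# one long-lived store for the lifetime of the process
store = torch_dist.TCPStore("127.0.0.1", 29500, 1, is_master=True,
                            timeout=timedelta(minutes=10))
for counter in (1, 2):
    # each round gets its own key namespace: "prefix-1/", "prefix-2/", ...
    sub_store = torch_dist.PrefixStore(f"prefix-{counter}", store)
    torch_dist.init_process_group(backend="gloo", rank=0, world_size=1, store=sub_store)
    torch_dist.destroy_process_group()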
 
519
- def store_based_barrier(
520
- self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
521
- ) -> None:
525
+ def store_based_barrier(self, timeout: timedelta = timedelta(minutes=5)) -> None:
522
526
  """
523
527
  Perform a store-based barrier synchronization across all ranks.
524
528
 
@@ -529,9 +533,9 @@ class ParameterServer:
529
533
  Args:
530
534
  store: The TCPStore instance to use for synchronization.
531
535
  """
532
- dist.distributed_c10d._store_based_barrier(
536
+ torch.distributed.distributed_c10d._store_based_barrier(
533
537
  rank=self._rank,
534
- store=store,
538
+ store=self._store,
535
539
  group_name="parameter_server_barrier",
536
540
  rendezvous_count=self._world_size,
537
541
  timeout=timeout,
@@ -544,8 +548,6 @@ class ParameterServer:
544
548
  *,
545
549
  timeout: timedelta = timedelta(minutes=10),
546
550
  ranks: list[int] | None = None,
547
- master_addr: str | None = None,
548
- master_port: int | None = None,
549
551
  ) -> None:
550
552
  """
551
553
  Update the checkpoint to inference engine. This function should be called after gather_metas.
@@ -566,28 +568,12 @@ class ParameterServer:
566
568
  assert req_func is not None, "req_func is required"
567
569
  ranks_group = None
568
570
  try:
569
- master_addr = os.getenv("MASTER_ADDR") or master_addr
570
- assert master_addr, "master_addr is required"
571
- if self._auto_pg:
572
- if not dist.is_initialized():
573
- self.init_process_group(
574
- timeout=timeout, master_addr=master_addr, master_port=master_port
575
- )
576
- manager_store = dist.distributed_c10d._get_default_store()
577
- else:
578
- # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
579
- # If master_port is provided, use master_port+1 for barrier store
580
- manager_store = dist.TCPStore(
581
- master_addr,
582
- _get_master_port(master_port) + 1,
583
- self._world_size,
584
- timeout=timeout,
585
- is_master=self._rank == 0,
586
- )
571
+ if self._auto_pg and not dist.is_initialized():
572
+ self.init_process_group(timeout=timeout)
587
573
  # if ranks is None or [], it will use fully broadcast to update to all ranks
588
574
  ranks_group = dist.new_group(ranks) if ranks else None
589
575
  self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
590
- self.store_based_barrier(manager_store)
576
+ self.store_based_barrier()
591
577
  except Exception as e:
592
578
  logger.exception(
593
579
  f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
@@ -616,7 +602,10 @@ class ParameterServer:
616
602
  return socket, socket_paths
617
603
 
618
604
  def _detect_bucket_size(
619
- self, ranks_group: dist.ProcessGroup | None, *, disable_h2d_buffer: bool = False
605
+ self,
606
+ ranks_group: dist.DistributedProcessGroup | None,
607
+ *,
608
+ disable_h2d_buffer: bool = False,
620
609
  ) -> tuple[int, bool]:
621
610
  GiB = 1 << 30 # noqa: N806
622
611
  # auto detect bucket size
@@ -633,7 +622,7 @@ class ParameterServer:
633
622
  dtype=torch.int64,
634
623
  device=self.device_manager.device_type,
635
624
  )
636
- dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group)
625
+ dist.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN, group=ranks_group)
637
626
  tensor = tensor.cpu()
638
627
  free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
639
628
  max_tensor_bytes = 0
@@ -735,7 +724,7 @@ class ParameterServer:
735
724
  self,
736
725
  checkpoint_name: str,
737
726
  req_func: Callable[[list[tuple[str, str]]], None],
738
- ranks_group: dist.ProcessGroup | None,
727
+ ranks_group: dist.DistributedProcessGroup | None,
739
728
  ranks: list[int] | None = None,
740
729
  ):
741
730
  assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
@@ -854,7 +843,7 @@ class ParameterServer:
854
843
  f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
855
844
  )
856
845
  ret_code.fill_(1)
857
- dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group)
846
+ dist.all_reduce(ret_code, op=torch.distributed.ReduceOp.SUM, group=ranks_group)
858
847
  self.device_manager.device_module.synchronize()
859
848
  if ret_code.item() != 0:
860
849
  # quit early if any rank failed
@@ -864,6 +853,29 @@ class ParameterServer:
864
853
  gidx += 1
865
854
 
866
855
  socket.recv()
856
+ device_mem = self.device_manager.device_module.mem_get_info()
857
+ logger.info(
858
+ f"[rank{self._rank}] weights broadcast done, device mem usage: {(device_mem[1] - device_mem[0]) / 1024 / 1024:.2f} MB, allocated memory: {self.device_manager.device_module.memory_allocated() / 1024 / 1024:.2f} MB, reserved memory: {self.device_manager.device_module.memory_reserved() / 1024 / 1024:.2f} MB"
859
+ )
860
+ # Notify worker to release handle
861
+ socket.send_pyobj(None)
862
+ socket.recv()
863
+ # Set to None in correct order (views first, then base tensors)
864
+ del buffer_b, h2d_buffer, buffer, handle
865
+ self.device_manager.device_module.synchronize()
866
+ gc.collect()
867
+ self.device_manager.device_module.ipc_collect()
868
+ self.device_manager.device_module.empty_cache()
869
+ self.device_manager.device_module.synchronize()
870
+
871
+ # Log actual memory usage
872
+ device_mem = self.device_manager.device_module.mem_get_info()
873
+ logger.info(
874
+ f"[rank{self._rank}] post-release: device mem usage: {(device_mem[1] - device_mem[0]) / 1024 / 1024:.2f} MB, "
875
+ f"allocated: {self.device_manager.device_module.memory_allocated() / 1024 / 1024:.2f} MB, "
876
+ f"reserved: {self.device_manager.device_module.memory_reserved() / 1024 / 1024:.2f} MB"
877
+ )
878
+ # Notify worker to call post_hook
867
879
  socket.send_pyobj(None)
868
880
  socket.recv()
869
881
  finally:
checkpoint_engine/worker.py CHANGED
@@ -10,6 +10,9 @@ import zmq
10
10
  from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
11
11
 
12
12
 
13
+ _WEIGHTS_TYPE = list[tuple[str, torch.Tensor]]
14
+
15
+
13
16
  def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
14
17
  func, args = handle
15
18
  list_args = list(args)
@@ -29,11 +32,9 @@ class FlattenedTensorMetadata(TypedDict):
29
32
  offset: int
30
33
 
31
34
 
32
- def _extract_weights(
33
- payload: list[FlattenedTensorMetadata], buffer: torch.Tensor
34
- ) -> list[tuple[str, torch.Tensor]]:
35
+ def _extract_weights(payload: list[FlattenedTensorMetadata], buffer: torch.Tensor) -> _WEIGHTS_TYPE:
35
36
  assert buffer is not None
36
- weights: list[tuple[str, torch.Tensor]] = []
37
+ weights: _WEIGHTS_TYPE = []
37
38
  for item in payload:
38
39
  shape = item["shape"]
39
40
  if isinstance(shape, list | tuple):
@@ -69,15 +70,35 @@ def update_weights_from_ipc(
69
70
  socket.send_string(msg)
70
71
  socket.recv() # wait for ack
71
72
  raise
73
+ # State machine:
74
+ # + receive tensor_metadata -> update_weights
75
+ # + receive Exception -> raise and stop
76
+ # + receive None first time -> release resources
77
+ # + receive None second time -> call post_hook and stop
72
78
  try:
79
+ released = False
73
80
  while True:
74
81
  payload: list[FlattenedTensorMetadata] | Exception | None = socket.recv_pyobj()
75
- if payload is None: # done signal
82
+ if released:
83
+ assert payload is None, "Should not receive any payload after released"
76
84
  if post_hook is not None:
77
85
  post_hook()
78
86
  device_manager.device_module.synchronize()
79
87
  socket.send(b"")
80
88
  break
89
+ if payload is None: # done signal
90
+ # TODO: wrap all messages in a dedicated message object instead of bare None and Exception
91
+ device_manager.device_module.synchronize()
92
+ released = True
93
+ buffer = None
94
+ del ipc_handle
95
+
96
+ gc.collect()
97
+ device_manager.device_module.ipc_collect()
98
+ device_manager.device_module.empty_cache()
99
+ device_manager.device_module.synchronize()
100
+ socket.send(b"")
101
+ continue
81
102
  if isinstance(payload, list): # still updating weights
82
103
  try:
83
104
  run(_extract_weights(payload, buffer))
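Read together with the sender side in ps.py (which, after the last bucket, calls socket.send_pyobj(None) and socket.recv() twice), the protocol the state machine above implements looks roughly like the sketch below; socket and buckets are placeholders here, not package APIs:

def drive_update(socket, buckets):
    # buckets: an iterable of list[FlattenedTensorMetadata], one entry per IPC bucket
    for metadata in buckets:
        socket.send_pyobj(metadata)   # worker: _extract_weights(...) then run(...)
        socket.recv()                 # ack for this bucket
    socket.send_pyobj(None)           # first None: worker drops the IPC buffer, gc + empty_cache
    socket.recv()
    socket.send_pyobj(None)           # second None: worker runs post_hook and leaves the loop
    socket.recv()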
@@ -166,12 +187,31 @@ class VllmColocateWorkerExtension:
166
187
  self.device = torch.device(f"npu:{self.local_rank}")
167
188
  assert self.device is not None
168
189
 
190
+ def _load_weights(weights: _WEIGHTS_TYPE):
191
+ # Load main model weights
192
+ self.model_runner.model.load_weights(weights)
193
+ # Load drafter model weights if MTP/speculative decoding is enabled
194
+ if (
195
+ getattr(self.model_runner, "drafter", None) is not None
196
+ and getattr(self.model_runner.drafter, "model", None) is not None
197
+ ):
198
+ self.model_runner.drafter.model.load_weights(weights=weights)
199
+
200
+ def _post_hook():
201
+ process_weights_after_loading(self.model_runner.model, self.model_config, self.device)
202
+ # Also trigger drafter model's post processing if MTP is enabled
203
+ if (
204
+ getattr(self.model_runner, "drafter", None) is not None
205
+ and getattr(self.model_runner.drafter, "model", None) is not None
206
+ ):
207
+ process_weights_after_loading(
208
+ self.model_runner.drafter.model, self.model_config, self.device
209
+ )
210
+
169
211
  update_weights_from_ipc(
170
212
  self._zmq_ctx,
171
213
  zmq_handles[self._device_uuid],
172
214
  device_id=self.device.index,
173
- run=self.model_runner.model.load_weights,
174
- post_hook=lambda: process_weights_after_loading(
175
- self.model_runner.model, self.model_config, self.device
176
- ),
215
+ run=_load_weights,
216
+ post_hook=_post_hook,
177
217
  )
checkpoint_engine-0.4.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.3.3
3
+ Version: 0.4.0
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
checkpoint_engine-0.4.0.dist-info/RECORD ADDED
@@ -0,0 +1,19 @@
1
+ checkpoint_engine/__init__.py,sha256=OeWxe9mxl2sZ6cW-blSTg6JbFlOMpGbBghLZtxGOqXk,942
2
+ checkpoint_engine/__main__.py,sha256=yzQlApuYo6eIOqtqM018RosyxNzXzB5a-stxUvsh-dg,709
3
+ checkpoint_engine/_version.py,sha256=2_0GUP7yBCXRus-qiJKxQD62z172WSs1sQ6DVpPsbmM,704
4
+ checkpoint_engine/api.py,sha256=JDiQ4i3Gb6GoaBhlp8lNuUPaVURoFFdeGJY9ZDDGvPc,3518
5
+ checkpoint_engine/data_types.py,sha256=O9uAXjwB20iwrOHfEEQd8Y9CmaFspNJ9ks9noHqwQKk,2716
6
+ checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
7
+ checkpoint_engine/p2p_store.py,sha256=abiCDVmRISPt9QFfavHB9Jo7ZpBbSjUS1NevGuB-AVA,8721
8
+ checkpoint_engine/pin_memory.py,sha256=b7nABKJV2bSIsOfX2YTHzUk1OkOze6AQjCaOIFaQnbA,16708
9
+ checkpoint_engine/ps.py,sha256=DQ9-hvZJW0eA9d6bU1glIbSwYl4cZCmVRjFPPRF41YY,41957
10
+ checkpoint_engine/worker.py,sha256=fTWiF6Gggehzjx4mnIFTDZFR-GwkEBUdTAC_ZLmsgZE,8649
11
+ checkpoint_engine/distributed/__init__.py,sha256=fC0EEX1nfWkg8OolzAj5vd2P0x6s4hScOlwV8q8Uiik,492
12
+ checkpoint_engine/distributed/base.py,sha256=dpdjcGXNdCdAUDPnX-vdJmCGXbGS4A69yNsd60t-UgA,7800
13
+ checkpoint_engine/distributed/vllm_hccl.py,sha256=bLE-GrnOxu1GTw_2GIqu2o67_Sw7vgjzJnlMvvQz_8c,10313
14
+ checkpoint_engine/distributed/vllm_nccl.py,sha256=nHnlY1jk--xNEjKDDnywx36FgrnjEGc9lrBBC3o-YzE,7015
15
+ checkpoint_engine-0.4.0.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
16
+ checkpoint_engine-0.4.0.dist-info/METADATA,sha256=qtS4bAI6SC3nBatKqAi5EVkty2zkvEZOWfshpaswF6k,11559
17
+ checkpoint_engine-0.4.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
18
+ checkpoint_engine-0.4.0.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
19
+ checkpoint_engine-0.4.0.dist-info/RECORD,,
checkpoint_engine-0.4.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.9.0)
2
+ Generator: setuptools (80.10.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
checkpoint_engine-0.3.3.dist-info/RECORD REMOVED
@@ -1,15 +0,0 @@
1
- checkpoint_engine/__init__.py,sha256=OeWxe9mxl2sZ6cW-blSTg6JbFlOMpGbBghLZtxGOqXk,942
2
- checkpoint_engine/__main__.py,sha256=yzQlApuYo6eIOqtqM018RosyxNzXzB5a-stxUvsh-dg,709
3
- checkpoint_engine/_version.py,sha256=lemL_4Kl75FgrO6lVuFrrtw6-Dcf9wtXBalKkXuzkO4,704
4
- checkpoint_engine/api.py,sha256=JDiQ4i3Gb6GoaBhlp8lNuUPaVURoFFdeGJY9ZDDGvPc,3518
5
- checkpoint_engine/data_types.py,sha256=O9uAXjwB20iwrOHfEEQd8Y9CmaFspNJ9ks9noHqwQKk,2716
6
- checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
7
- checkpoint_engine/p2p_store.py,sha256=abiCDVmRISPt9QFfavHB9Jo7ZpBbSjUS1NevGuB-AVA,8721
8
- checkpoint_engine/pin_memory.py,sha256=b7nABKJV2bSIsOfX2YTHzUk1OkOze6AQjCaOIFaQnbA,16708
9
- checkpoint_engine/ps.py,sha256=wBsHu2qWy5oRBrvLc7aEOroG_j58UJoWT6lFH4ylMRk,41092
10
- checkpoint_engine/worker.py,sha256=ghj9d2u8hY_U2uiOZWIN2CqRNZH6PrzujT22fHUFBWI,6879
11
- checkpoint_engine-0.3.3.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
12
- checkpoint_engine-0.3.3.dist-info/METADATA,sha256=WyyGLw1qrteQgRGOWhAm15NN2nzklTqw4iiQ9U2nYpQ,11559
13
- checkpoint_engine-0.3.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
14
- checkpoint_engine-0.3.3.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
15
- checkpoint_engine-0.3.3.dist-info/RECORD,,