checkpoint-engine 0.3.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1576 @@
1
+ import argparse
2
+ import concurrent.futures
3
+ import ctypes
4
+ import json
5
+ import os
6
+ import pickle
7
+ import random
8
+ import threading
9
+ import time
10
+ from collections import defaultdict
11
+ from collections.abc import Callable
12
+ from datetime import timedelta
13
+ from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
14
+
15
+ import httpx
16
+ import numpy as np
17
+ import torch
18
+ import torch.distributed as dist
19
+ import zmq
20
+ from loguru import logger
21
+ from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
22
+ from safetensors.torch import _getdtype, safe_open
23
+ from torch.multiprocessing.reductions import reduce_tensor
24
+
25
+ from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
26
+
27
+
28
+ if TYPE_CHECKING:
29
+ from typing import TypeVar
30
+
31
+ from typing_extensions import TypedDict
32
+
33
+ class FileMeta(TypedDict):
34
+ key: str # parameter name
35
+ dtype: torch.dtype
36
+ shape: torch.Size
37
+ type: type
38
+ tp_concat_dim: int
39
+
40
+ T = TypeVar("T")
41
+
42
+
43
+ def _dt_validate(value: Any) -> torch.dtype:
44
+ if isinstance(value, str):
45
+ if not value.startswith("torch."):
46
+ raise ValueError(f"dtype {value} should start with torch.")
47
+ try:
48
+ value = getattr(torch, value.split(".")[1])
49
+ except AttributeError as e:
50
+ raise ValueError(f"unknown dtype: {value}") from e
51
+ if not isinstance(value, torch.dtype):
52
+ raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
53
+ return value
54
+
55
+
56
+ _TorchDtype = Annotated[
57
+ torch.dtype,
58
+ PlainValidator(_dt_validate),
59
+ PlainSerializer(lambda x: str(x), return_type=str),
60
+ WithJsonSchema({"type": "string"}, mode="serialization"),
61
+ ]
62
+
63
+
64
+ def _size_validate(value: Any) -> torch.Size:
65
+ if isinstance(value, list | tuple):
66
+ return torch.Size(value)
67
+ if not isinstance(value, torch.Size):
68
+ raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
69
+ return value
70
+
71
+
72
+ _TorchSize = Annotated[
73
+ torch.Size,
74
+ PlainValidator(_size_validate),
75
+ PlainSerializer(lambda x: tuple(x), return_type=tuple),
76
+ WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
77
+ ]
78
+
79
+
80
+ def _tensor_validate(value: Any) -> torch.Tensor:
81
+ if isinstance(value, torch.Tensor):
82
+ return value
83
+ raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}")
84
+
85
+
86
+ _TorchTensor = Annotated[
87
+ torch.Tensor,
88
+ PlainValidator(_tensor_validate),
89
+ ]
90
+
91
+
92
+ class ParameterMeta(BaseModel):
93
+ name: str
94
+ dtype: _TorchDtype
95
+ shape: _TorchSize
96
+ aligned_size: int
97
+
98
+
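A minimal sketch of how the annotated torch types above behave (assuming pydantic v2 and a recent torch; the field values are hypothetical and only illustrate the validators/serializers defined above):

import torch

# dtype accepts either a torch.dtype or a "torch.*" string; shape accepts a tuple/list
m = ParameterMeta(name="w", dtype="torch.float16", shape=(2, 3), aligned_size=256)
assert m.dtype is torch.float16 and m.shape == torch.Size([2, 3])
# serialization turns the dtype back into a string and the shape into a tuple
assert m.model_dump()["dtype"] == "torch.float16"
assert m.model_dump()["shape"] == (2, 3)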
99
+ class BucketRange(NamedTuple):
100
+ idx: int # bucket_idx of MemoryBucket in memory_pool
101
+ offset: int
102
+ size: int
103
+
104
+
105
+ class H2DBucket(BaseModel):
106
+ size: int
107
+ ranges: list[BucketRange]
108
+ items: list[ParameterMeta]
109
+
110
+
111
+ class MemoryBufferMetas(BaseModel):
112
+ metas: list[ParameterMeta]
113
+ ptr: int
114
+ size: int
115
+
116
+
117
+ class MemoryBuffer(BaseModel):
118
+ buffer: _TorchTensor
119
+ size: int
120
+ metas: list[ParameterMeta]
121
+
122
+
123
+ class MemoryBufferMetaList(BaseModel):
124
+ p2p_store_addr: str | None
125
+ memory_buffer_metas_list: list[MemoryBufferMetas]
126
+ rdma_device: str
127
+
128
+
129
+ class DataToGather(MemoryBufferMetaList):
130
+ host_ip: str
131
+ device_uuid: str
132
+
133
+
134
+ # 256-byte alignment when flattening torch tensors into a uint8 buffer
135
+ _ALIGN_SIZE = 256
136
+
137
+
138
+ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
139
+ return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE
140
+
141
+
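A quick illustration of the 256-byte alignment (a minimal sketch; the dtype and shape are hypothetical):

import torch

# a bfloat16 tensor of shape (1000,) occupies 2000 bytes,
# which _align_size rounds up to the next multiple of 256, i.e. 2048
assert _align_size(torch.bfloat16, torch.Size([1000])) == 2048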
142
+ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
143
+ ret = []
144
+ for meta in metas:
145
+ size = meta.aligned_size
146
+ ret.append(
147
+ {
148
+ "name": meta.name,
149
+ "dtype": meta.dtype,
150
+ "shape": meta.shape,
151
+ "offset": offset,
152
+ }
153
+ )
154
+ offset += size
155
+ return ret
156
+
157
+
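A minimal sketch of the offset layout produced by _to_named_tensor (the metas below are hypothetical):

import torch

metas = [
    ParameterMeta(name="a", dtype=torch.float32, shape=torch.Size([4]), aligned_size=256),
    ParameterMeta(name="b", dtype=torch.float32, shape=torch.Size([4]), aligned_size=256),
]
# offsets advance by each parameter's aligned_size, starting from the given base offset
assert [x["offset"] for x in _to_named_tensor(metas)] == [0, 256]
assert [x["offset"] for x in _to_named_tensor(metas, offset=1024)] == [1024, 1280]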
158
+ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]:
159
+ def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
160
+ ret = {}
161
+ with safe_open(fn, framework="pt") as f:
162
+ for name in f.keys(): # noqa: SIM118
163
+ weight = f.get_tensor(name)
164
+ meta = {
165
+ "key": name,
166
+ "dtype": weight.dtype,
167
+ "shape": weight.shape,
168
+ "type": type(weight),
169
+ "tp_concat_dim": -1, # safetensors does not support tp_concat_dim
170
+ }
171
+ ret[name] = (meta, weight)
172
+ return ret
173
+
174
+ # deprecated, will be removed in the future
175
+ def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
176
+ """load *.np file and return memmap and related tensor meta"""
177
+
178
+ def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
179
+ start = fin.tell()
180
+ major, minor = np.lib.format.read_magic(fin)
181
+ if major == 1 and minor == 0:
182
+ read_header_fn = np.lib.format.read_array_header_1_0
183
+ elif major == 2 and minor == 0:
184
+ read_header_fn = np.lib.format.read_array_header_2_0
185
+ else:
186
+ raise ValueError(
187
+ f"unknown version {major}.{minor} when parsing npy header from {fn}"
188
+ )
189
+ shape, is_fortran, dtype = read_header_fn(fin)
190
+ return {
191
+ "shape": shape,
192
+ "is_fortran": is_fortran,
193
+ "dtype": dtype,
194
+ "header_length": fin.tell() - start,
195
+ }
196
+
197
+ meta_fn = fn + ".meta"
198
+ with open(meta_fn, "rb") as fin:
199
+ meta_lst = pickle.load(fin)
200
+
201
+ tensors = []
202
+ offset = 0
203
+ with open(fn, "rb") as fin:
204
+ fin.seek(0, os.SEEK_END)
205
+ filesize = fin.tell()
206
+ fin.seek(0)
207
+ while fin.tell() < filesize:
208
+ tensor_meta = parse_npy_header(fin)
209
+ tensor = np.memmap(
210
+ fn,
211
+ dtype=tensor_meta["dtype"],
212
+ mode="c",
213
+ offset=offset + tensor_meta["header_length"],
214
+ shape=tensor_meta["shape"],
215
+ )
216
+ offset += tensor_meta["header_length"] + tensor.nbytes
217
+ fin.seek(offset)
218
+ tensors.append(tensor)
219
+
220
+ assert len(meta_lst) == len(tensors)
221
+ ret = {}
222
+ for meta, tensor in zip(meta_lst, tensors):
223
+ if meta["type"] == torch.Tensor:
224
+ tensor = torch.from_numpy(tensor)
225
+ tensor = tensor.view(dtype=meta["dtype"]).view(*meta["shape"])
226
+ ret[meta["key"]] = (meta, tensor)
227
+ return ret
228
+
229
+ tp_rank = 0
230
+ if file_path.endswith(".npy"):
231
+ logger.warning("numpy model file is deprecated, will be removed in the future")
232
+ filename_split = os.path.basename(file_path).split(".")
233
+ # if using numpy and want to specify tp rank
234
+ # file should be in model.{layer}.{tp}[.{ep}].npy format
235
+ tp_rank = int(filename_split[2]) if len(filename_split) > 3 else 0
236
+ ret = _fast_np_load(file_path)
237
+ elif file_path.endswith(".safetensors"):
238
+ ret = _safetensors_load(file_path)
239
+ else:
240
+ raise ValueError(f"unsupported file format: {file_path}")
241
+ return tp_rank, ret
242
+
243
+
244
+ def _concat_tp_weights(
245
+ tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int
246
+ ) -> torch.Tensor:
247
+ """Concat tp weights with meta info.
248
+ If tp_concat_dim is -1, the weights are shared across TP ranks, so just use the first one.
249
+ Otherwise, the weights are concatenated along tp_concat_dim.
250
+ """
251
+ if tp_concat_dim == -1:
252
+ return tp_weights[0]
253
+ assert tp_size == len(tp_weights)
254
+ if len(tp_weights) == 1:
255
+ return tp_weights[0]
256
+ return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
257
+
258
+
259
+ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
260
+ try:
261
+ if device_manager.device_type == "npu":
262
+ return f"NPU-{npu_generate_uuid()}"
263
+ else:
264
+ return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}"
265
+ except AssertionError as e:
266
+ raise ValueError(f"fail to get physical gpu id {device_index}") from e
267
+
268
+
269
+ def _ibv_get_device_list() -> list[str]:
270
+ lib = ctypes.CDLL("libibverbs.so.1")
271
+ lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices
272
+ lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p) # struct ibv_device **
273
+
274
+ lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
275
+ lib.ibv_get_device_name.argtypes = [ctypes.c_void_p] # struct ibv_device *
276
+ lib.ibv_get_device_name.restype = ctypes.c_char_p # const char *
277
+
278
+ num = ctypes.c_int()
279
+ dev_array = lib.ibv_get_device_list(ctypes.byref(num))
280
+ if not dev_array or num.value <= 0:
281
+ return []
282
+
283
+ devices = []
284
+ for i in range(num.value):
285
+ dev_ptr = dev_array[i] # struct ibv_device *
286
+ name = lib.ibv_get_device_name(dev_ptr) # const char *
287
+ devices.append(name.decode())
288
+ lib.ibv_free_device_list(dev_array)
289
+ return devices
290
+
291
+
292
+ def _get_rdma_devices() -> list[str]:
293
+ """
294
+ Get the RDMA device list. PS_P2P_STORE_RDMA_DEVICES takes precedence if set; otherwise NCCL_IB_HCA is parsed against the devices reported by _ibv_get_device_list.
295
+ """
296
+ devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
297
+ if devices_str:
298
+ return devices_str.split(",")
299
+ # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
300
+ hca = os.getenv("NCCL_IB_HCA", None)
301
+ return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()
302
+
303
+
304
+ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
305
+ """
306
+ Assign an RDMA NIC to each local rank. For example, with devices "mlx5_0,mlx5_1" on an 8-GPU node, local ranks 0-3 share mlx5_0 and ranks 4-7 share mlx5_1.
307
+ """
308
+ if not devices:
309
+ raise RuntimeError("no rdma devices found")
310
+ try:
311
+ assert len(devices) <= gpu_count, (
312
+ f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
313
+ )
314
+ assert gpu_count % len(devices) == 0, (
315
+ f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
316
+ )
317
+ return devices[local_rank // (gpu_count // len(devices))]
318
+ except AssertionError:
319
+ logger.error(
320
+ "Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices."
321
+ "The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices."
322
+ "The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
323
+ )
324
+ raise
325
+
326
+
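A minimal sketch of the rank-to-NIC mapping described above (device names and GPU count are hypothetical):

# with 8 GPUs and two NICs, local ranks 0-3 map to mlx5_0 and ranks 4-7 to mlx5_1
assert _get_my_rdma_device(local_rank=2, gpu_count=8, devices=["mlx5_0", "mlx5_1"]) == "mlx5_0"
assert _get_my_rdma_device(local_rank=5, gpu_count=8, devices=["mlx5_0", "mlx5_1"]) == "mlx5_1"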
327
+ def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
328
+ """
329
+ The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
330
+ This Python parser mirrors the C++ parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.
331
+
332
+ The list is comma-separated; port numbers are NOT supported yet.
333
+ An optional prefix '^' indicates the list is an exclude list.
334
+ A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix.
335
+ Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported.
336
+
337
+ Examples:
338
+ - `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
339
+ - `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
340
+ - `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
341
+ - `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
342
+ """
343
+ max_hcas = 32
344
+ if not value or value.strip() == "":
345
+ return available_devices[:max_hcas]
346
+
347
+ value = value.strip()
348
+ result = []
349
+ is_exclude = value.startswith("^")
350
+ if is_exclude:
351
+ value = value.removeprefix("^")
352
+ is_exact_match = value.startswith("=")
353
+ if is_exact_match:
354
+ value = value.removeprefix("=")
355
+
356
+ device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]
357
+
358
+ result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
359
+ if is_exclude:
360
+ result = [dev for dev in available_devices if dev not in result]
361
+ if len(result) > max_hcas:
362
+ result = result[:max_hcas]
363
+
364
+ logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")
365
+
366
+ return result
367
+
368
+
369
+ def _resolve_device_specs(
370
+ device_specs: list[str], is_exact_match: bool, available_devices: list[str]
371
+ ) -> list[str]:
372
+ devices = set()
373
+ for spec in device_specs:
374
+ parts = spec.split(":", 1)
375
+ device_name = parts[0].strip()
376
+ # HACK: mooncake transfer engine does not support port specification yet, so we ignore it
377
+ # port = parts[1].strip() if len(parts) > 1 else None
378
+ base_devices = (
379
+ [device_name]
380
+ if device_name in available_devices
381
+ else []
382
+ if is_exact_match
383
+ else [dev for dev in available_devices if dev.startswith(device_name)]
384
+ )
385
+
386
+ if not base_devices:
387
+ logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
388
+ continue
389
+
390
+ for base_dev in base_devices:
391
+ devices.add(base_dev)
392
+
393
+ return sorted(devices)
394
+
395
+
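A minimal sketch of how the prefix/exact-match/exclude rules combine (the device names are hypothetical):

available = ["mlx5_0", "mlx5_1", "mlx5_2", "bond0"]
# bare token: prefix match
assert _parse_NCCL_IB_HCA("mlx5", available) == ["mlx5_0", "mlx5_1", "mlx5_2"]
# '^=' prefix: exclude the exact names that follow
assert _parse_NCCL_IB_HCA("^=mlx5_0,mlx5_1", available) == ["mlx5_2", "bond0"]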
396
+ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
397
+ class TPMeta(BaseModel):
398
+ concat_dim: int
399
+ size: int
400
+
401
+ parameters: dict[str, torch.Tensor] = {}
402
+ parameter_metas: dict[str, ParameterMeta] = {}
403
+ tp_metas: dict[str, TPMeta] = {}
404
+ parameters_with_tp: dict[str, dict[int, torch.Tensor]] = {}
405
+ for file in files:
406
+ tp_rank, ret = _load_checkpoint_file(file)
407
+ for parameter_name, (meta, weight) in ret.items():
408
+ if parameter_name not in parameters_with_tp:
409
+ parameters_with_tp[parameter_name] = {}
410
+ parameters_with_tp[parameter_name][tp_rank] = weight
411
+ if parameter_name not in tp_metas:
412
+ tp_metas[parameter_name] = TPMeta(
413
+ concat_dim=meta["tp_concat_dim"],
414
+ size=1,
415
+ )
416
+ if parameter_name not in parameter_metas:
417
+ assert isinstance(meta["dtype"], torch.dtype), (
418
+ f"meta {meta} dtype should be torch.dtype"
419
+ )
420
+ assert isinstance(meta["shape"], torch.Size), (
421
+ f"meta {meta} shape should be torch.Size"
422
+ )
423
+ parameter_metas[parameter_name] = ParameterMeta(
424
+ name=parameter_name,
425
+ shape=meta["shape"],
426
+ dtype=meta["dtype"],
427
+ aligned_size=_align_size(meta["dtype"], meta["shape"]),
428
+ )
429
+ tp_meta = tp_metas[parameter_name]
430
+ if tp_meta.concat_dim != -1:
431
+ tp_meta.size = max(tp_meta.size, tp_rank + 1)
432
+ for name, tp_meta in tp_metas.items():
433
+ if tp_meta.concat_dim != -1:
434
+ shape = list(parameter_metas[name].shape)
435
+ shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
436
+ parameter_metas[name] = ParameterMeta(
437
+ name=name,
438
+ shape=torch.Size(shape),
439
+ dtype=parameter_metas[name].dtype,
440
+ aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
441
+ )
442
+ weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
443
+ # TODO: the concat here is serial, which may be slow,
444
+ # but since tp-sharded storage will not be used going forward
445
+ # we ignore this performance issue for now
446
+ parameters[name] = _concat_tp_weights(weights_in_cpu, tp_meta.concat_dim, tp_meta.size)
447
+ for name, parameter in parameters.items():
448
+ assert name in parameter_metas, f"parameter {name} not found in parameter_metas"
449
+ assert parameter_metas[name].shape == parameter.shape, (
450
+ f"parameter {name} shape mismatch, {parameter_metas[name].shape} != {parameter.shape}"
451
+ )
452
+ assert parameter_metas[name].dtype == parameter.dtype, (
453
+ f"parameter {name} dtype mismatch, {parameter_metas[name].dtype} != {parameter.dtype}"
454
+ )
455
+ return parameters
456
+
457
+
458
+ def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
459
+ def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
460
+ """
461
+ safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
462
+ We load the safetensors file as bytes, then parse the header manually to get parameter metas.
463
+ The actual tensor data is in the remaining bytes and is naturally aligned.
464
+ We pin the remaining bytes as the buffer, making pinning faster.
465
+ """
466
+
467
+ def _pin(t: torch.Tensor):
468
+ """
469
+ Pin the memory of tensor in-place.
470
+ See: https://github.com/pytorch/pytorch/issues/32167
471
+ """
472
+ cudart = torch.cuda.cudart()
473
+ r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
474
+ assert r == 0, f"pin memory error, error code: {r}"
475
+
476
+ # TODO: should this be restricted to /dev/shm? Files on disk also appear to work.
477
+ size = os.stat(file_path).st_size
478
+ flag_size = 8
479
+ t = torch.from_file(file_path, True, size, dtype=torch.uint8)
480
+ assert t.nbytes > flag_size, (
481
+ f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
482
+ )
483
+ start_pos = (
484
+ int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
485
+ + flag_size
486
+ )
487
+ header_tensor = t[flag_size:start_pos]
488
+ header = json.loads(header_tensor.numpy().tobytes())
489
+ if "__metadata__" in header:
490
+ header.pop("__metadata__")
491
+
492
+ metas: list[ParameterMeta] = []
493
+ offset = 0
494
+ try:
495
+ for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
496
+ start, end = meta["data_offsets"]
497
+ # safetensors format ensures offsets are aligned
498
+ assert offset == start, f"offset {offset} should be equal to start {start}"
499
+ metas.append(
500
+ ParameterMeta(
501
+ name=name,
502
+ dtype=_getdtype(meta["dtype"]),
503
+ shape=torch.Size(meta["shape"]),
504
+ aligned_size=end - start,
505
+ )
506
+ )
507
+ offset = end
508
+ except Exception as e:
509
+ logger.error(f"fail to parse safetensors header from {file_path}: {e}")
510
+ raise
511
+
512
+ buffer = t[start_pos:]
513
+ assert offset == buffer.nbytes, (
514
+ f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
515
+ )
516
+ # Remove the file after successfully loading. This will avoid doubling the memory usage.
517
+ # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
518
+ os.remove(file_path)
519
+ _pin(buffer)
520
+ logger.info(
521
+ f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
522
+ )
523
+ return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
524
+
525
+ memory_buffers: list[MemoryBuffer] = []
526
+ with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
527
+ memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
528
+ return memory_buffers
529
+
530
+
531
+ def _normal_pin_memory(
532
+ files: list[str],
533
+ named_tensors: dict[str, torch.Tensor],
534
+ rank: int | None = None,
535
+ shared_pin_memory: list[MemoryBuffer] | None = None,
536
+ ) -> list[MemoryBuffer]:
537
+ parameters = _load_checkpoint(files)
538
+ if named_tensors:
539
+ parameters.update(named_tensors)
540
+ bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))
541
+
542
+ class MemoryBucket(BaseModel):
543
+ size: int
544
+ metas: list[ParameterMeta]
545
+
546
+ buckets: list[MemoryBucket] = []
547
+ buckets.append(MemoryBucket(size=0, metas=[]))
548
+ for name, tensor in sorted(parameters.items()):
549
+ size = _align_size(tensor.dtype, tensor.shape)
550
+ if buckets[-1].size + size > bucket_size:
551
+ assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
552
+ buckets.append(MemoryBucket(size=0, metas=[]))
553
+ buckets[-1].metas.append(
554
+ ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
555
+ )
556
+ buckets[-1].size += size
557
+
558
+ memory_buffers = [
559
+ MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
560
+ for bucket in buckets
561
+ ]
562
+
563
+ def register_pin_memory(
564
+ idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
565
+ ) -> tuple[int, torch.Tensor]:
566
+ if shared_pin_memory:
567
+ # If shared_pin_memory is provided, reuse the pinned memory buffer instead of allocating a new one.
568
+ # Reusing pinned memory only supports checkpoints with the fixed shape registered the first time.
569
+ assert idx < len(shared_pin_memory), (
570
+ f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
571
+ )
572
+ assert shared_pin_memory[idx].size == size, (
573
+ f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
574
+ )
575
+ return idx, shared_pin_memory[idx].buffer
576
+ else:
577
+ buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
578
+ return idx, buffer
579
+
580
+ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
581
+ buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
582
+
583
+ with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
584
+ futures = [
585
+ executor.submit(
586
+ register_pin_memory,
587
+ idx,
588
+ bucket.size,
589
+ shared_pin_memory,
590
+ )
591
+ for idx, bucket in enumerate(buckets)
592
+ ]
593
+ new_futures = []
594
+ for future in concurrent.futures.as_completed(futures):
595
+ idx, buffer = future.result()
596
+ assert buffer.numel() == buckets[idx].size, (
597
+ f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}"
598
+ )
599
+ memory_buffers[idx].buffer = buffer
600
+ logger.info(
601
+ f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, "
602
+ f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
603
+ )
604
+ offset = 0
605
+ for meta in buckets[idx].metas:
606
+ name = meta.name
607
+ tensor = parameters[name]
608
+ size = _align_size(tensor.dtype, tensor.shape)
609
+ assert size == _align_size(meta.dtype, meta.shape), (
610
+ f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}"
611
+ )
612
+ new_futures.append(executor.submit(register_tensor, buffer, offset, tensor))
613
+ offset += size
614
+ for future in concurrent.futures.as_completed(new_futures):
615
+ future.result()
616
+ return memory_buffers
617
+
618
+
619
+ def _register_checkpoint(
620
+ *,
621
+ files: list[str],
622
+ named_tensors: dict[str, torch.Tensor],
623
+ rank: int | None = None,
624
+ shared_pin_memory: list[MemoryBuffer] | None = None,
625
+ ) -> list[MemoryBuffer]:
626
+ logger.info(
627
+ f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
628
+ )
629
+ if not files and not named_tensors:
630
+ return []
631
+ memory_buffers: list[MemoryBuffer] = []
632
+ files_to_inplace_pin = [
633
+ file
634
+ for file in files
635
+ if file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108
636
+ ]
637
+ files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
638
+ if files_to_normal_pin or named_tensors:
639
+ memory_buffers.extend(
640
+ _normal_pin_memory(
641
+ files=files_to_normal_pin,
642
+ named_tensors=named_tensors,
643
+ rank=rank,
644
+ shared_pin_memory=shared_pin_memory,
645
+ )
646
+ )
647
+ if files_to_inplace_pin:
648
+ memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
649
+ return memory_buffers
650
+
651
+
652
+ def request_inference_to_update(
653
+ url: str,
654
+ socket_paths: dict[str, str],
655
+ timeout: float = 300.0,
656
+ uds: str | None = None,
657
+ ):
658
+ """Send an inference update request to inference server via HTTP or Unix socket.
659
+
660
+ Args:
661
+ url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
662
+ socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights.
663
+ timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
664
+ uds (str, optional): Path to a Unix domain socket. If provided, the request
665
+ will be sent via the Unix socket instead of HTTP. Defaults to None.
666
+
667
+ Raises:
668
+ httpx.HTTPStatusError: If the response contains an HTTP error status.
669
+ httpx.RequestError: If there was an issue while making the request.
670
+ """
671
+ resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post(
672
+ url,
673
+ json={
674
+ "method": "update_weights_from_ipc",
675
+ "args": [socket_paths],
676
+ "timeout": timeout,
677
+ },
678
+ timeout=timeout,
679
+ )
680
+ resp.raise_for_status()
681
+
682
+
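A hedged usage sketch of request_inference_to_update; the URL, device uuid and socket path below are placeholders, not values produced by this package:

request_inference_to_update(
    "http://localhost:19730/inference",
    {"GPU-xxxxxxxx": "ipc://@checkpoint-engine-GPU-xxxxxxxx-0.sock"},
    timeout=300.0,
)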
683
+ def _gen_h2d_buckets(
684
+ global_metas: dict[int, MemoryBufferMetaList],
685
+ bucket_size: int,
686
+ local_topo: dict[str, set[int]],
687
+ remote_topo: dict[str, set[int]],
688
+ ranks: list[int] | None = None,
689
+ ) -> list[tuple[int, int, H2DBucket]]:
690
+ buckets: list[tuple[int, H2DBucket]] = []
691
+
692
+ for owner_rank, items in global_metas.items():
693
+ buckets.append((owner_rank, H2DBucket(size=0, ranges=[], items=[])))
694
+ for idx, metas in enumerate(items.memory_buffer_metas_list):
695
+ start_offset, offset = 0, 0
696
+ for meta in metas.metas:
697
+ s = meta.aligned_size
698
+ if buckets[-1][1].size + s > bucket_size:
699
+ if offset - start_offset > 0:
700
+ buckets[-1][1].ranges.append(
701
+ BucketRange(idx, start_offset, offset - start_offset)
702
+ )
703
+ start_offset = offset
704
+ buckets.append((owner_rank, H2DBucket(size=0, ranges=[], items=[])))
705
+ offset += s
706
+ buckets[-1][1].size += s
707
+ buckets[-1][1].items.append(meta)
708
+ buckets[-1][1].ranges.append(BucketRange(idx, start_offset, offset - start_offset))
709
+ assert buckets[-1][1].size > 0, (
710
+ f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
711
+ )
712
+ ranks_set = set(ranks) if ranks else set()
713
+ actual_local_topo = (
714
+ {k: v & ranks_set for k, v in local_topo.items() if v & ranks_set} if ranks else local_topo
715
+ )
716
+ # if ranks is empty, assign the owner_rank as receiver_rank, this is used for colocate architecture
717
+ if not ranks:
718
+ return [(owner_rank, owner_rank, bucket) for owner_rank, bucket in buckets]
719
+ else:
720
+ return _assign_receiver_ranks(buckets, actual_local_topo, remote_topo)
721
+
722
+
723
+ def _assign_receiver_ranks(
724
+ buckets: list[tuple[int, "T"]],
725
+ local_topo: dict[str, set[int]],
726
+ remote_topo: dict[str, set[int]],
727
+ ) -> list[tuple[int, int, "T"]]:
728
+ """
729
+ (owner_rank, bucket) -> (receiver_rank, owner_rank, bucket)
730
+
731
+ Assign receiver ranks to buckets. If ranks is empty, assign the owner_rank as receiver_rank.
732
+ GPU-rdma_device topology will be considered to make full use of the bandwidth.
733
+ """
734
+ if not buckets:
735
+ logger.warning("bucket list is empty, no need to assign receiver ranks")
736
+ return []
737
+ rank_to_rdma_device = {
738
+ rank: rdma_device for rdma_device, ranks in remote_topo.items() for rank in ranks
739
+ }
740
+
741
+ # group buckets by owner RDMA devices
742
+ buckets_by_rdma_device = defaultdict(list)
743
+ for owner_rank, bucket in buckets:
744
+ owner_rdma_device = rank_to_rdma_device[owner_rank]
745
+ buckets_by_rdma_device[owner_rdma_device].append((owner_rank, bucket))
746
+
747
+ buckets_matrix = list(buckets_by_rdma_device.values())
748
+ assert buckets_matrix, "buckets_matrix should not be empty"
749
+
750
+ # Select receiver ranks. We use the minimum rank in each local RDMA device group as receiver rank
751
+ num_receivers = min(len(local_topo), len(buckets_by_rdma_device))
752
+ receiver_list = [min(ranks) for ranks in list(local_topo.values())[:num_receivers]]
753
+
754
+ flattened_buckets = [
755
+ buckets_matrix[row][col]
756
+ for col in range(
757
+ max(len(matrix_row) for matrix_row in buckets_matrix) if buckets_matrix else 0
758
+ )
759
+ for row in range(len(buckets_matrix))
760
+ if col < len(buckets_matrix[row])
761
+ ]
762
+
763
+ buckets_with_receiver = []
764
+ assigned_cnt = 0
765
+ while assigned_cnt < len(flattened_buckets):
766
+ occupied_devices = set()
767
+ for receiver_rank in receiver_list:
768
+ if assigned_cnt >= len(flattened_buckets):
769
+ break
770
+ owner_rank, bucket = flattened_buckets[assigned_cnt]
771
+ rdma_device = rank_to_rdma_device[owner_rank]
772
+ if rdma_device in occupied_devices:
773
+ break
774
+ buckets_with_receiver.append((receiver_rank, owner_rank, bucket))
775
+ occupied_devices.add(rdma_device)
776
+ assigned_cnt += 1
777
+
778
+ return buckets_with_receiver
779
+
780
+
781
+ def _get_master_port(master_port: int | None = None) -> int:
782
+ if master_port is None:
783
+ # HACK: use MASTER_PORT + 1 as master_port to avoid conflict with torchrun's rendezvous port
784
+ # TODO: check whether master_port is available or use a more elegant way
785
+ master_port = int(os.getenv("MASTER_PORT")) + 1
786
+ return master_port
787
+
788
+
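A minimal sketch of the port derivation (assuming MASTER_PORT is set in the environment; 29500 is only an example value):

os.environ.setdefault("MASTER_PORT", "29500")
assert _get_master_port() == int(os.environ["MASTER_PORT"]) + 1
assert _get_master_port(12345) == 12345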
789
+ class P2PStore:
790
+ def __init__(self, device_manager: DeviceManager):
791
+ from mooncake.engine import TransferEngine
792
+
793
+ self.rank = int(os.getenv("RANK"))
794
+ gpu_count = device_manager.device_module.device_count()
795
+ local_rank = self.rank % gpu_count
796
+ device_type = device_manager.device_type
797
+ if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
798
+ self.device = ""
799
+ else:
800
+ self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
801
+ self.ip = get_ip()
802
+
803
+ # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
804
+ retry_count = 8
805
+ for i in range(retry_count):
806
+ self.engine = TransferEngine()
807
+ ret = self.engine.initialize(
808
+ self.ip,
809
+ "P2PHANDSHAKE",
810
+ "ascend_direct" if device_type == "npu" else "rdma",
811
+ self.device,
812
+ )
813
+ if ret == 0:
814
+ break
815
+ # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
816
+ sleep_ms = random.randint(500, 2000)
817
+ logger.warning(
818
+ f"[rank{self.rank}] fail to initialize transfer engine, ret {ret}, retry {i + 1}/{retry_count} in {sleep_ms}ms"
819
+ )
820
+ time.sleep(sleep_ms / 1000)
821
+ else:
822
+ raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine")
823
+ self.port = self.engine.get_rpc_port()
824
+ self.named_tensors: dict[str, torch.Tensor] = {}
825
+ logger.info(
826
+ f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
827
+ )
828
+
829
+ @property
830
+ def addr(self) -> str:
831
+ return f"{self.ip}:{self.port}"
832
+
833
+ def register_named_tensors(self, named_tensors: dict[str, torch.Tensor]):
834
+ buffer_addresses = [tensor.data_ptr() for tensor in named_tensors.values()]
835
+ capacities = [tensor.nbytes for tensor in named_tensors.values()]
836
+ self.named_tensors.update(named_tensors)
837
+ for i, name in enumerate(named_tensors.keys()):
838
+ logger.info(
839
+ f"[rank{self.rank}] p2p store register tensor {name} with addr {hex(buffer_addresses[i])} and capacity {capacities[i]}"
840
+ )
841
+ assert self.engine.batch_register_memory(buffer_addresses, capacities) == 0
842
+
843
+ def unregister_named_tensors(self, names: list[str]) -> int:
844
+ buffer_addresses = [self.named_tensors[name].data_ptr() for name in names]
845
+ assert self.engine.batch_unregister_memory(buffer_addresses) == 0
846
+ num_unregistered = 0
847
+ for i, name in enumerate(names):
848
+ del self.named_tensors[name]
849
+ logger.info(
850
+ f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}"
851
+ )
852
+ num_unregistered += 1
853
+ return num_unregistered
854
+
855
+ def batch_transfer_sync_read(
856
+ self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]
857
+ ):
858
+ assert (
859
+ self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0
860
+ )
861
+
862
+
863
+ class ParameterServer:
864
+ shared_memory_pool_name = "__shared_memory_pool__"
865
+
866
+ def __init__(
867
+ self,
868
+ *,
869
+ rank: int | None = None,
870
+ world_size: int | None = None,
871
+ auto_pg: bool = False,
872
+ gpu_count: int | None = None,
873
+ mem_fraction: float | None = None,
874
+ ):
875
+ """
876
+ Initialize the parameter server. The environment variables RANK, WORLD_SIZE and MASTER_ADDR must be set.
877
+
878
+ Args:
879
+ auto_pg: Whether to automatically initialize the process group.
880
+ Note that if auto_pg is True, the process group will be destroyed after the update.
882
+ mem_fraction: The fraction of the currently free device memory available for allocation.
882
+ """
883
+ self._rank = rank or int(os.environ.get("RANK", None))
884
+ self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
885
+ self.device_manager = DeviceManager()
886
+ self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
887
+ self._local_rank = self._rank % self._gpu_count
888
+ self._auto_pg = auto_pg
889
+ self._all_hosts = []
890
+ self._global_device_uuids: list[str] = []
891
+ self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
892
+ self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
893
+ self._mem_fraction = mem_fraction or 0.9
894
+
895
+ assert self._rank is not None and self._rank >= 0, self._rank
896
+ assert self._world_size and self._world_size > 0, self._world_size
897
+ assert (
898
+ self._gpu_count is not None
899
+ and self._gpu_count > 0
900
+ and self._gpu_count <= self.device_manager.device_module.device_count()
901
+ ), self._gpu_count
902
+ assert (
903
+ self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
904
+ ), self._mem_fraction
905
+
906
+ self._zmq_ctx = zmq.Context()
907
+ self._zmq_addr_counter = 0
908
+
909
+ # stores the name of the checkpoint currently using the shared memory pool, or empty string if none
910
+ self._current_shared_memory_pool_user: str = ""
911
+ self._memory_pool: dict[str, list[MemoryBuffer]] = {}
912
+ self._memory_pool[self.shared_memory_pool_name] = []
913
+ # dict key is owner_rank, value is the list of bucket metas owned by that rank
914
+ self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
915
+ # NPU transfer engine initialization requires prior set_device.
916
+ device_index = self._local_rank
917
+ self.device_manager.device_module.set_device(device_index)
918
+ try:
919
+ self._p2p_store = P2PStore(self.device_manager)
920
+ except ImportError as e:
921
+ logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
922
+ self._p2p_store = None
923
+
924
+ self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
925
+ self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
926
+
927
+ def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
928
+ if checkpoint_name == self._current_shared_memory_pool_user:
929
+ assert self._memory_pool[self.shared_memory_pool_name], (
930
+ f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
931
+ )
932
+ return self._memory_pool[self.shared_memory_pool_name]
933
+ elif checkpoint_name in self._memory_pool:
934
+ return self._memory_pool[checkpoint_name]
935
+ else:
936
+ raise RuntimeError(f"checkpoint {checkpoint_name} is not registered")
937
+
938
+ def _logger_rank0(self, msg: str):
939
+ if self._local_rank == 0:
940
+ logger.info(msg)
941
+
942
+ def get_metas(self) -> dict[int, MemoryBufferMetaList]:
943
+ return self._current_global_parameter_metas
944
+
945
+ def load_metas(self, metas: dict[int, MemoryBufferMetaList]):
946
+ self._current_global_parameter_metas = metas
947
+ self._remote_rdma_devices = defaultdict(set)
948
+ for i, meta in self._current_global_parameter_metas.items():
949
+ assert meta.rdma_device is not None, "meta.rdma_device should not be None"
950
+ assert meta.p2p_store_addr is not None, "meta.p2p_store_addr should not be None"
951
+ self._remote_rdma_devices[
952
+ meta.rdma_device + "@" + meta.p2p_store_addr.split(":")[0]
953
+ ].add(i)
954
+
955
+ def register_checkpoint(
956
+ self,
957
+ checkpoint_name: str,
958
+ *,
959
+ files: list[str] | None = None,
960
+ named_tensors: dict[str, torch.Tensor] | None = None,
961
+ use_shared_memory_pool: bool = False,
962
+ ) -> None:
963
+ """
964
+ Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
965
+ Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
966
+ Please make sure to copy the files to disk first if you need to keep them.
967
+
968
+ Args:
969
+ checkpoint_name: The name of the checkpoint.
970
+ files: The safetensors files to register.
971
+ named_tensors: The named tensors to register.
972
+ use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
973
+ Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
974
+ cannot accommodate checkpoints with different memory requirements.
975
+ To free the actual memory of the shared pool or to modify its shape,
976
+ please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
977
+ """
978
+ try:
979
+ if use_shared_memory_pool:
980
+ logger.info(
981
+ f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool"
982
+ )
983
+ assert self._current_shared_memory_pool_user == "", (
984
+ f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
985
+ f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
986
+ f"This registration may cause unexpected conflicts."
987
+ )
988
+ # Since we set the uninitialized shared memory pool to empty list,
989
+ # we can check whether this is the first time to use shared memory pool
990
+ _is_first_time = not self._memory_pool[self.shared_memory_pool_name]
991
+ self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
992
+ files=files or [],
993
+ named_tensors=named_tensors or {},
994
+ rank=self._rank,
995
+ shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
996
+ )
997
+ self._current_shared_memory_pool_user = checkpoint_name
998
+ if self._p2p_store is not None and _is_first_time:
999
+ self._register_parameters_to_p2p_store(checkpoint_name)
1000
+ else:
1001
+ assert checkpoint_name not in self._memory_pool, (
1002
+ f"checkpoint {checkpoint_name} already registered"
1003
+ )
1004
+ self._memory_pool[checkpoint_name] = _register_checkpoint(
1005
+ files=files or [], named_tensors=named_tensors or {}, rank=self._rank
1006
+ )
1007
+ if self._p2p_store is not None:
1008
+ self._register_parameters_to_p2p_store(checkpoint_name)
1009
+ except Exception:
1010
+ logger.exception(
1011
+ f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
1012
+ )
1013
+ if self._p2p_store is not None and not use_shared_memory_pool:
1014
+ self._unregister_parameters_from_p2p_store(checkpoint_name)
1015
+ self.unregister_checkpoint(checkpoint_name)
1016
+ raise
1017
+
1018
+ def unregister_checkpoint(self, checkpoint_name: str, force: bool = False) -> None:
1019
+ """
1020
+ Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint
1021
+ from p2p store if p2p store is initialized.
1022
+ Args:
1023
+ checkpoint_name: The name of the checkpoint.
1024
+ force: This flag is designed for the shared memory pool user. If True, the memory of the shared memory pool itself will be freed.
1025
+ If False, only the checkpoint name will be unregistered, and the shared memory pool will be kept for future use.
1026
+ """
1027
+ if (
1028
+ checkpoint_name not in self._memory_pool
1029
+ and checkpoint_name != self._current_shared_memory_pool_user
1030
+ ):
1031
+ logger.warning(
1032
+ f"[rank{self._rank}] unregister checkpoint name {checkpoint_name} not found"
1033
+ )
1034
+ return
1035
+
1036
+ if checkpoint_name == self._current_shared_memory_pool_user and not force:
1037
+ self._current_shared_memory_pool_user = ""
1038
+ return
1039
+
1040
+ if self._p2p_store is not None:
1041
+ num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name)
1042
+ logger.info(
1043
+ f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
1044
+ )
1045
+
1046
+ if checkpoint_name == self._current_shared_memory_pool_user:
1047
+ self._current_shared_memory_pool_user = ""
1048
+ del self._memory_pool[self.shared_memory_pool_name]
1049
+ self._memory_pool[self.shared_memory_pool_name] = []
1050
+ else:
1051
+ del self._memory_pool[checkpoint_name]
1052
+ # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
1053
+ # this works by using torch>=2.5.0
1054
+ torch._C._host_emptyCache()
1055
+
1056
+ def gather_metas(self, checkpoint_name: str):
1057
+ """
1058
+ Gather the parameter metas from all ranks. This gathers the memory buffer metas along with other metadata.
1059
+ This function should be called before `update`; it assigns a new value to `self._current_global_parameter_metas`,
1060
+ which can be exported via the `self.get_metas` function.
1061
+ """
1062
+ if self._auto_pg and not dist.is_initialized():
1063
+ self.init_process_group()
1064
+ assert dist.is_initialized(), "process group is not initialized"
1065
+ metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)] # type: ignore
1066
+ try:
1067
+ memory_pool = self._get_memory_pool(checkpoint_name)
1068
+ except RuntimeError:
1069
+ memory_pool = []
1070
+ metas = DataToGather(
1071
+ memory_buffer_metas_list=[
1072
+ MemoryBufferMetas(
1073
+ metas=x.metas,
1074
+ ptr=x.buffer.data_ptr(),
1075
+ size=x.size,
1076
+ )
1077
+ for x in memory_pool
1078
+ ],
1079
+ p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
1080
+ host_ip=get_ip(),
1081
+ device_uuid=self._device_uuid,
1082
+ rdma_device=self._rdma_device or "",
1083
+ )
1084
+
1085
+ dist.all_gather_object(metas_lst, metas)
1086
+
1087
+ self._current_global_parameter_metas = {}
1088
+
1089
+ num_parameters = 0
1090
+ all_hosts: list[str] = []
1091
+ global_device_uuids: list[str] = []
1092
+ for i, metas_buckets in enumerate(metas_lst):
1093
+ assert metas_buckets is not None, f"metas_buckets {i} should not be None"
1094
+ if i % self._gpu_count == 0 and not self._all_hosts:
1095
+ all_hosts.append(metas_buckets.host_ip)
1096
+ if not self._global_device_uuids:
1097
+ global_device_uuids.append(metas_buckets.device_uuid)
1098
+ if metas_buckets.memory_buffer_metas_list:
1099
+ self._current_global_parameter_metas[i] = MemoryBufferMetaList(
1100
+ memory_buffer_metas_list=metas_buckets.memory_buffer_metas_list,
1101
+ p2p_store_addr=metas_buckets.p2p_store_addr,
1102
+ rdma_device=metas_buckets.rdma_device,
1103
+ )
1104
+ num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list)
1105
+ self._local_rdma_devices[
1106
+ metas_buckets.rdma_device + "@" + metas_buckets.p2p_store_addr.split(":")[0]
1107
+ if metas_buckets.p2p_store_addr
1108
+ else metas_buckets.host_ip
1109
+ ].add(i)
1110
+ if not self._all_hosts:
1111
+ self._all_hosts = all_hosts
1112
+ if not self._global_device_uuids:
1113
+ self._global_device_uuids = global_device_uuids
1114
+ # By default, the sender and receiver nodes are assumed to share the same GPU-to-RDMA-device topology.
1115
+ # The sender's topology (_remote_rdma_devices) can be overwritten by calling load_metas.
1116
+ self._remote_rdma_devices = self._local_rdma_devices.copy()
1117
+ logger.info(
1118
+ f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}"
1119
+ )
1120
+
1121
+ def init_process_group(
1122
+ self,
1123
+ *,
1124
+ master_addr: str | None = None,
1125
+ master_port: int | None = None,
1126
+ timeout: timedelta = timedelta(minutes=10),
1127
+ ):
1128
+ """
1129
+ Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
1130
+
1131
+ Args:
1132
+ master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
1133
+ timeout: The timeout of the process group.
1134
+ """
1135
+ master_addr = master_addr or os.getenv("MASTER_ADDR")
1136
+ assert master_addr, "master_addr is required"
1137
+ store = dist.TCPStore(
1138
+ master_addr,
1139
+ _get_master_port(master_port),
1140
+ self._world_size,
1141
+ timeout=timeout,
1142
+ is_master=self._rank == 0,
1143
+ )
1144
+ dist.init_process_group(
1145
+ backend=self.device_manager.backend,
1146
+ world_size=self._world_size,
1147
+ rank=self._rank,
1148
+ timeout=timeout,
1149
+ store=store,
1150
+ )
1151
+ logger.info(f"[rank{self._rank}] init process group successfully.")
1152
+
1153
+ def store_based_barrier(
1154
+ self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
1155
+ ) -> None:
1156
+ """
1157
+ Perform a store-based barrier synchronization across all ranks.
1158
+
1159
+ This barrier uses a TCP store directly rather than a process group,
1160
+ allowing all ranks to synchronize regardless of which process group
1161
+ they belong to.
1162
+
1163
+ Args:
1164
+ store: The TCPStore instance to use for synchronization.
1165
+ """
1166
+ dist.distributed_c10d._store_based_barrier(
1167
+ rank=self._rank,
1168
+ store=store,
1169
+ group_name="parameter_server_barrier",
1170
+ rendezvous_count=self._world_size,
1171
+ timeout=timeout,
1172
+ )
1173
+
1174
+ def update(
1175
+ self,
1176
+ checkpoint_name: str,
1177
+ req_func: Callable[[list[tuple[str, str]]], None],
1178
+ *,
1179
+ timeout: timedelta = timedelta(minutes=10),
1180
+ ranks: list[int] | None = None,
1181
+ master_addr: str | None = None,
1182
+ master_port: int | None = None,
1183
+ ) -> None:
1184
+ """
1185
+ Update the checkpoint to inference engine. This function should be called after gather_metas.
1186
+
1187
+ Args:
1188
+ checkpoint_name: The name of the checkpoint.
1189
+ req_func: The function to request the inference of inference engine.
1190
+ ranks: The ranks to update. If not set, a full broadcast is used to update all ranks,
1191
+ which is the fastest way to update weights, especially in a colocated architecture.
1192
+ If set, p2p is used to update only the given ranks; this is flexible for updating a subset of ranks
1193
+ and is useful in a disaggregated architecture.
1194
+ master_addr: The master address for process group initialization. If not set, will use env MASTER_ADDR.
1195
+ master_port: The master port for process group initialization. If not set, will use _get_master_port to get the port, which will use MASTER_PORT+1.
1196
+ timeout: The timeout of the barrier operation.
1197
+ """
1198
+ assert req_func is not None, "req_func is required"
1199
+ ranks_group = None
1200
+ try:
1201
+ master_addr = os.getenv("MASTER_ADDR") or master_addr
1202
+ assert master_addr, "master_addr is required"
1203
+ if self._auto_pg:
1204
+ if not dist.is_initialized():
1205
+ self.init_process_group(
1206
+ timeout=timeout, master_addr=master_addr, master_port=master_port
1207
+ )
1208
+ manager_store = dist.distributed_c10d._get_default_store()
1209
+ else:
1210
+ # HACK: use MASTER_PORT+2 for the barrier store if master_port is not provided (_get_master_port() returns MASTER_PORT+1).
1211
+ # If master_port is provided, use master_port+1 for the barrier store.
1212
+ manager_store = dist.TCPStore(
1213
+ master_addr,
1214
+ _get_master_port(master_port) + 1,
1215
+ self._world_size,
1216
+ timeout=timeout,
1217
+ is_master=self._rank == 0,
1218
+ )
1219
+ # if ranks is None or [], a full broadcast is used to update all ranks
1220
+ ranks_group = dist.new_group(ranks if ranks else None)
1221
+ self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
1222
+ self.store_based_barrier(manager_store)
1223
+ except Exception as e:
1224
+ logger.exception(
1225
+ f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
1226
+ )
1227
+ raise
1228
+ finally:
1229
+ if ranks_group:
1230
+ dist.destroy_process_group(ranks_group)
1231
+ if self._auto_pg and dist.is_initialized():
1232
+ dist.destroy_process_group()
1233
+ self.device_manager.device_module.empty_cache()
1234
+ logger.info(
1235
+ f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
1236
+ f"Current device allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
1237
+ f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
1238
+ )
1239
+
1240
+ def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
1241
+ def zmq_handle(device_uuid: str) -> str:
1242
+ return f"ipc://@checkpoint-engine-{device_uuid}-{self._zmq_addr_counter}.sock"
1243
+
1244
+ socket_paths = [(uid, zmq_handle(uid)) for uid in self._global_device_uuids]
1245
+ socket = self._zmq_ctx.socket(zmq.REQ)
1246
+ socket.bind(zmq_handle(self._device_uuid))
1247
+ self._zmq_addr_counter += 1
1248
+ return socket, socket_paths
1249
+
1250
+ def _detect_bucket_size(
1251
+ self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
1252
+ ) -> tuple[int, bool]:
1253
+ GiB = 1 << 30 # noqa: N806
1254
+ # auto detect bucket size
1255
+ tensor = torch.tensor(
1256
+ [
1257
+ # proportion of current device free memory bytes
1258
+ int(
1259
+ float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction
1260
+ ),
1261
+ # we negate the value so the all-reduce MIN operation
1263
+ # yields the max value of zmq_addr_counter across all ranks
1263
+ -self._zmq_addr_counter,
1264
+ ],
1265
+ dtype=torch.int64,
1266
+ device=self.device_manager.device_type,
1267
+ )
1268
+ dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group)
1269
+ tensor = tensor.cpu()
1270
+ free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
1271
+ max_tensor_bytes = 0
1272
+ for items in self._current_global_parameter_metas.values():
1273
+ for metas_list in items.memory_buffer_metas_list:
1274
+ for meta in metas_list.metas:
1275
+ max_tensor_bytes = max(max_tensor_bytes, meta.aligned_size)
1276
+ free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
1277
+ if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
1278
+ self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
1279
+ # using h2d_buffer lets all ranks run their H2D copies in parallel;
1281
+ # the cost is the extra GPU memory allocated for h2d_buffer
1281
+ free_bytes = free_bytes_divided_3
1282
+ else:
1283
+ # if memory is insufficient, fall back to disable_h2d_buffer mode;
1284
+ # in this mode, bandwidth is limited by the H2D throughput of a single machine,
1285
+ # but GPU memory is saved
1286
+ self._logger_rank0(
1287
+ f"[rank{self._rank}] disable h2d buffer when max_tensor_bytes {max_tensor_bytes} is larger than free_bytes {free_bytes} // 3"
1288
+ )
1289
+ free_bytes = free_bytes // (2 * _ALIGN_SIZE) * _ALIGN_SIZE
1290
+ assert max_tensor_bytes <= free_bytes, (
1291
+ f"max_tensor_bytes {max_tensor_bytes} should be less than free_bytes {free_bytes}"
1292
+ )
1293
+ disable_h2d_buffer = True
1294
+ max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB
1295
+ bucket_size = min(max(max_bytes, max_tensor_bytes), free_bytes)
1296
+ logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size / GiB:.2f} GiB")
1297
+ return bucket_size, disable_h2d_buffer
1298
+
1299
+ def _copy_to_buffer(
1300
+ self,
1301
+ checkpoint_name: str,
1302
+ bucket: H2DBucket,
1303
+ buffer: torch.Tensor,
1304
+ owner_rank: int | None = None,
1305
+ ):
1306
+ offset = 0
1307
+ if owner_rank is not None:
1308
+ buf_ptrs, remote_ptrs, lens = [], [], []
1309
+ ptr_base = buffer.data_ptr()
1310
+ target_addr, ptrs = self._get_addr_ptrs(owner_rank)
1311
+ for b in bucket.ranges:
1312
+ assert offset + b.size <= bucket.size, (
1313
+ f"offset {offset} + size {b.size} > bucket_size {bucket.size}"
1314
+ )
1315
+ if owner_rank is not None:
1316
+ buf_ptrs.append(ptr_base + offset)
1317
+ remote_ptrs.append(ptrs[b.idx][0] + b.offset)
1318
+ lens.append(b.size)
1319
+ else:
1320
+ pool = self._get_memory_pool(checkpoint_name)[b.idx]
1321
+ buffer[offset : offset + b.size].data.copy_(
1322
+ pool.buffer[b.offset : b.offset + b.size],
1323
+ non_blocking=True,
1324
+ )
1325
+ offset += b.size
1326
+ assert offset == bucket.size, f"offset {offset} != bucket_size {bucket.size}"
1327
+ if owner_rank is not None:
1328
+ self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
1329
+ self.device_manager.device_module.synchronize()
1330
+
1331
+ def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
1332
+ addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
1333
+ metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
1334
+ return addr, [(metas.ptr, metas.size) for metas in metas_list]
1335
+
1336
+ def _register_parameters_to_p2p_store(self, checkpoint_name: str):
1337
+ assert self._p2p_store is not None, "p2p store is not initialized"
1338
+ pool = self._get_memory_pool(checkpoint_name)
1339
+ if len(pool) == 0:
1340
+ return
1341
+ named_tensors, tensor_ptrs = {}, []
1342
+ register_name = (
1343
+ checkpoint_name
1344
+ if checkpoint_name != self._current_shared_memory_pool_user
1345
+ else self.shared_memory_pool_name
1346
+ )
1347
+ for idx, memory_buffer in enumerate(pool):
1348
+ named_tensors[f"memory_pool_{register_name}_{idx}"] = memory_buffer.buffer
1349
+ tensor_ptrs.append((memory_buffer.buffer.data_ptr(), memory_buffer.size))
1350
+ self._p2p_store.register_named_tensors(named_tensors)
1351
+
1352
+ def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
1353
+ assert self._p2p_store is not None, "p2p store is not initialized"
1354
+ pool = self._get_memory_pool(checkpoint_name)
1355
+ if len(pool) == 0:
1356
+ return 0
1357
+ unregister_name = (
1358
+ checkpoint_name
1359
+ if checkpoint_name != self._current_shared_memory_pool_user
1360
+ else self.shared_memory_pool_name
1361
+ )
1362
+ return self._p2p_store.unregister_named_tensors(
1363
+ [f"memory_pool_{unregister_name}_{idx}" for idx, _ in enumerate(pool)]
1364
+ )
1365
+
1366
+ def _update_per_bucket(
1367
+ self,
1368
+ checkpoint_name: str,
1369
+ req_func: Callable[[list[tuple[str, str]]], None],
1370
+ ranks_group: dist.ProcessGroup,
1371
+ ranks: list[int] | None = None,
1372
+ ):
1373
+ assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
1374
+ assert dist.is_initialized(), "process group is not initialized"
1375
+
1376
+ # if ranks is None or [], a full broadcast is used to update all ranks
1377
+ if not ranks:
1378
+ logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
1379
+ # if ranks is set, p2p is used to update those ranks
1380
+ else:
1381
+ assert self._p2p_store is not None, "p2p store is not initialized"
1382
+ assert ranks, "ranks should be set"
1383
+
1384
+ need_update = self._rank in ranks
1385
+ logger.info(
1386
+ f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
1387
+ f"gpu_count {self._gpu_count}, world_size {self._world_size}"
1388
+ )
1389
+
1390
+ if not need_update:
1391
+ return
1392
+ # execute a barrier first to avoid subsequent device OOM
1393
+ dist.barrier(group=ranks_group)
1394
+
1395
+ bucket_size, disable_h2d_buffer = self._detect_bucket_size(ranks_group)
1396
+ buckets = _gen_h2d_buckets(
1397
+ self._current_global_parameter_metas,
1398
+ bucket_size,
1399
+ self._local_rdma_devices,
1400
+ self._remote_rdma_devices,
1401
+ ranks,
1402
+ )
1403
+
1404
+ h2d_buffer: torch.Tensor | None = (
1405
+ None
1406
+ if disable_h2d_buffer
1407
+ else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
1408
+ )
1409
+ # the p2p store needs to register h2d_buffer so that other ranks can read it
1410
+ if ranks:
1411
+ h2d_buffer_name = "__h2d_buffer__"
1412
+ if h2d_buffer is not None and self._p2p_store is not None:
1413
+ self._p2p_store.register_named_tensors({h2d_buffer_name: h2d_buffer})
1414
+ receiver_rank_buckets: list[tuple[int, H2DBucket]] = []
1415
+ for receiver_rank, owner_rank, bucket in buckets:
1416
+ if receiver_rank != self._rank:
1417
+ continue
1418
+ receiver_rank_buckets.append((owner_rank, bucket))
1419
+
1420
+ buffer = torch.empty(
1421
+ bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type
1422
+ )
1423
+ handle = reduce_tensor(buffer)
1424
+
1425
+ buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
1426
+ max_len = 0
1427
+ for receiver_rank, _, bucket in buckets:
1428
+ buckets_by_receiver_rank[receiver_rank].append(bucket)
1429
+ if len(buckets_by_receiver_rank[receiver_rank]) > max_len:
1430
+ max_len = len(buckets_by_receiver_rank[receiver_rank])
1431
+
1432
+ socket, socket_paths = self._bind_zmq_socket()
1433
+ req_thread = threading.Thread(
1434
+ target=req_func,
1435
+ args=(socket_paths,),
1436
+ )
1437
+ req_thread.start()
1438
+ socket.send_pyobj(handle)
1439
+
1440
+ gidx = 0
1441
+ ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
1442
+ try:
1443
+ for i in range(max_len):
1444
+ if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
1445
+ self._copy_to_buffer(
1446
+ checkpoint_name,
1447
+ receiver_rank_buckets[i][1],
1448
+ h2d_buffer,
1449
+ receiver_rank_buckets[i][0] if ranks else None,
1450
+ )
1451
+ for receiver_rank, _buckets in buckets_by_receiver_rank.items():
1452
+ if i >= len(_buckets):
1453
+ continue
1454
+ bucket = _buckets[i]
1455
+ alloc, reserved = (
1456
+ self.device_manager.device_module.memory_allocated() / 1024 / 1024,
1457
+ self.device_manager.device_module.memory_reserved() / 1024 / 1024,
1458
+ )
1459
+ self._logger_rank0(
1460
+ f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
1461
+ f"Current device allocated {alloc:.2f} MB, "
1462
+ f"reserved {reserved:.2f} MB."
1463
+ )
1464
+ start = gidx % 2 * bucket_size
1465
+ buffer_b: torch.Tensor = buffer[start : start + bucket.size]
1466
+ if receiver_rank == self._rank:
1467
+ if disable_h2d_buffer:
1468
+ self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
1469
+ else:
1470
+ buffer_b.data.copy_(h2d_buffer[: bucket.size])
1471
+ dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
1472
+ resp = socket.recv()
1473
+ if resp != b"":
1474
+ msg = resp.decode("utf-8")
1475
+ logger.error(
1476
+ f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
1477
+ )
1478
+ ret_code.fill_(1)
1479
+ dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group)
1480
+ self.device_manager.device_module.synchronize()
1481
+ if ret_code.item() != 0:
1482
+ # quit early if any rank failed
1483
+ socket.send_pyobj(RuntimeError("Some workers failed to update weights"))
1484
+ raise RuntimeError("Failed to update weights due to remote errors")
1485
+ socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
1486
+ gidx += 1
1487
+
1488
+ socket.recv()
1489
+ socket.send_pyobj(None)
1490
+ socket.recv()
1491
+ finally:
1492
+ req_thread.join()
1493
+ dist.barrier(group=ranks_group)
1494
+ socket.close()
1495
+ if ranks and h2d_buffer is not None:
1496
+ self._p2p_store.unregister_named_tensors([h2d_buffer_name])
1497
+
1498
+ self.device_manager.device_module.empty_cache()
1499
+
1500
+
1501
+ def _init_api(ps: ParameterServer) -> Any:
1502
+ import fastapi
1503
+ from fastapi import Request
1504
+ from fastapi.responses import JSONResponse, Response
1505
+
1506
+ app = fastapi.FastAPI()
1507
+
1508
+ class RegisterRequest(BaseModel):
1509
+ files: list[str]
1510
+
1511
+ class UpdateRequest(BaseModel):
1512
+ ranks: list[int] = []
1513
+ update_url: str | None = None
1514
+ inference_group_ranks: list[int] = []
1515
+ timeout: float = 300.0
1516
+ uds: str | None = None
1517
+
1518
+ def wrap_exception(func: Callable[[], None]) -> Response:
1519
+ try:
1520
+ func()
1521
+ except Exception as e: # noqa: BLE001
1522
+ logger.exception(f"wrap exception {func} failed")
1523
+ return JSONResponse(content=str(e), status_code=500)
1524
+ return Response(status_code=200)
1525
+
1526
+ @app.post("/v1/checkpoints/{checkpoint_name}/files")
1527
+ async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:
1528
+ return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))
1529
+
1530
+ @app.delete("/v1/checkpoints/{checkpoint_name}")
1531
+ async def unregister_checkpoint(checkpoint_name: str) -> Response:
1532
+ return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))
1533
+
1534
+ @app.get("/v1/healthz")
1535
+ async def healthz() -> Response:
1536
+ return Response(status_code=200)
1537
+
1538
+ @app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
1539
+ async def gather_metas(checkpoint_name: str) -> Response:
1540
+ return wrap_exception(lambda: ps.gather_metas(checkpoint_name))
1541
+
1542
+ @app.post("/v1/checkpoints/{checkpoint_name}/update")
1543
+ async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
1544
+ def update_func(socket_paths: list[tuple[str, str]]):
1545
+ if req.update_url is None:
1546
+ return
1547
+ if req.inference_group_ranks:
1548
+ socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
1549
+ request_inference_to_update(
1550
+ req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
1551
+ )
1552
+
1553
+ return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))
1554
+
1555
+ return app
1556
+
1557
+
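A hedged sketch of driving the FastAPI endpoints above over the Unix domain socket; the socket path, checkpoint name, and file path are hypothetical:

import httpx

client = httpx.Client(transport=httpx.HTTPTransport(uds="/tmp/ps.sock"), base_url="http://checkpoint-engine")
# register files, gather metas across ranks, then trigger a broadcast update
client.post("/v1/checkpoints/demo/files", json={"files": ["/dev/shm/model.safetensors"]})
client.post("/v1/checkpoints/demo/gather-metas")
client.post("/v1/checkpoints/demo/update", json={"ranks": []})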
1558
+ @logger.catch(reraise=True)
1559
+ def run_from_cli():
1560
+ import uvicorn
1561
+
1562
+ parser = argparse.ArgumentParser(description="Parameter Server")
1563
+ parser.add_argument("--uds", type=str)
1564
+
1565
+ args = parser.parse_args()
1566
+ logger.info(
1567
+ f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
1568
+ )
1569
+
1570
+ assert args.uds and len(args.uds) > 0, args.uds
1571
+ ps = ParameterServer(auto_pg=True)
1572
+ uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60)
1573
+
1574
+
1575
+ if __name__ == "__main__":
1576
+ run_from_cli()