sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_one_batch.py +4 -0
  2. sglang/bench_serving.py +13 -0
  3. sglang/check_env.py +1 -1
  4. sglang/srt/_custom_ops.py +118 -0
  5. sglang/srt/configs/device_config.py +17 -0
  6. sglang/srt/configs/load_config.py +84 -0
  7. sglang/srt/configs/model_config.py +161 -4
  8. sglang/srt/configs/qwen2vl.py +5 -8
  9. sglang/srt/constrained/outlines_backend.py +6 -1
  10. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  11. sglang/srt/distributed/__init__.py +3 -0
  12. sglang/srt/distributed/communication_op.py +34 -0
  13. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  14. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  15. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  16. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  17. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  21. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  22. sglang/srt/distributed/parallel_state.py +1275 -0
  23. sglang/srt/distributed/utils.py +223 -0
  24. sglang/srt/hf_transformers_utils.py +37 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  26. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  27. sglang/srt/layers/fused_moe_patch.py +20 -11
  28. sglang/srt/layers/linear.py +1 -0
  29. sglang/srt/layers/logits_processor.py +17 -3
  30. sglang/srt/layers/quantization/__init__.py +34 -0
  31. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  32. sglang/srt/lora/lora.py +1 -1
  33. sglang/srt/managers/io_struct.py +48 -2
  34. sglang/srt/managers/schedule_batch.py +18 -14
  35. sglang/srt/managers/schedule_policy.py +7 -4
  36. sglang/srt/managers/scheduler.py +76 -20
  37. sglang/srt/managers/tokenizer_manager.py +166 -68
  38. sglang/srt/managers/tp_worker.py +36 -3
  39. sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
  40. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  41. sglang/srt/model_executor/forward_batch_info.py +9 -4
  42. sglang/srt/model_executor/model_runner.py +136 -150
  43. sglang/srt/model_loader/__init__.py +34 -0
  44. sglang/srt/model_loader/loader.py +1139 -0
  45. sglang/srt/model_loader/utils.py +41 -0
  46. sglang/srt/model_loader/weight_utils.py +640 -0
  47. sglang/srt/models/baichuan.py +9 -10
  48. sglang/srt/models/chatglm.py +6 -15
  49. sglang/srt/models/commandr.py +2 -3
  50. sglang/srt/models/dbrx.py +2 -3
  51. sglang/srt/models/deepseek.py +4 -11
  52. sglang/srt/models/deepseek_v2.py +3 -11
  53. sglang/srt/models/exaone.py +2 -3
  54. sglang/srt/models/gemma.py +2 -6
  55. sglang/srt/models/gemma2.py +3 -14
  56. sglang/srt/models/gemma2_reward.py +0 -1
  57. sglang/srt/models/gpt2.py +5 -12
  58. sglang/srt/models/gpt_bigcode.py +6 -22
  59. sglang/srt/models/grok.py +3 -3
  60. sglang/srt/models/internlm2.py +2 -3
  61. sglang/srt/models/internlm2_reward.py +0 -1
  62. sglang/srt/models/llama.py +97 -27
  63. sglang/srt/models/llama_classification.py +1 -2
  64. sglang/srt/models/llama_embedding.py +1 -2
  65. sglang/srt/models/llama_reward.py +2 -3
  66. sglang/srt/models/llava.py +1 -4
  67. sglang/srt/models/llavavid.py +1 -2
  68. sglang/srt/models/minicpm.py +4 -7
  69. sglang/srt/models/minicpm3.py +6 -19
  70. sglang/srt/models/mixtral.py +12 -5
  71. sglang/srt/models/mixtral_quant.py +2 -3
  72. sglang/srt/models/mllama.py +3 -7
  73. sglang/srt/models/olmo.py +2 -8
  74. sglang/srt/models/olmo2.py +0 -1
  75. sglang/srt/models/olmoe.py +3 -5
  76. sglang/srt/models/phi3_small.py +8 -8
  77. sglang/srt/models/qwen.py +2 -3
  78. sglang/srt/models/qwen2.py +10 -9
  79. sglang/srt/models/qwen2_moe.py +4 -11
  80. sglang/srt/models/qwen2_vl.py +2 -6
  81. sglang/srt/models/registry.py +99 -0
  82. sglang/srt/models/stablelm.py +2 -3
  83. sglang/srt/models/torch_native_llama.py +6 -12
  84. sglang/srt/models/xverse.py +2 -4
  85. sglang/srt/models/xverse_moe.py +4 -11
  86. sglang/srt/models/yivl.py +2 -3
  87. sglang/srt/openai_api/adapter.py +9 -5
  88. sglang/srt/openai_api/protocol.py +1 -0
  89. sglang/srt/server.py +267 -170
  90. sglang/srt/server_args.py +65 -31
  91. sglang/srt/utils.py +245 -28
  92. sglang/test/test_utils.py +7 -0
  93. sglang/version.py +1 -1
  94. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
  95. sglang-0.4.0.dist-info/RECORD +184 -0
  96. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  97. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  98. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  99. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/device_communicators/cuda_wrapper.py
@@ -0,0 +1,182 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/cuda_wrapper.py
+ """This file is a pure Python wrapper for the cudart library.
+ It avoids the need to compile a separate shared library, and is
+ convenient for use when we just need to call a few functions.
+ """
+
+ import ctypes
+ import logging
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional
+
+ # this line makes it possible to directly load `libcudart.so` using `ctypes`
+ import torch  # noqa
+
+ logger = logging.getLogger(__name__)
+
+ # === export types and functions from cudart to Python ===
+ # for the original cudart definition, please check
+ # https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
+
+ cudaError_t = ctypes.c_int
+ cudaMemcpyKind = ctypes.c_int
+
+
+ class cudaIpcMemHandle_t(ctypes.Structure):
+     _fields_ = [("internal", ctypes.c_byte * 128)]
+
+
+ @dataclass
+ class Function:
+     name: str
+     restype: Any
+     argtypes: List[Any]
+
+
+ def find_loaded_library(lib_name) -> Optional[str]:
+     """
+     According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
+     the file `/proc/self/maps` contains the memory maps of the process, which includes the
+     shared libraries loaded by the process. We can use this file to find the path of
+     a loaded library.
+     """  # noqa
+     found = False
+     with open("/proc/self/maps") as f:
+         for line in f:
+             if lib_name in line:
+                 found = True
+                 break
+     if not found:
+         # the library is not loaded in the current process
+         return None
+     # if lib_name is libcudart, we need to match a line with:
+     # address /path/to/libcudart-hash.so.11.0
+     start = line.index("/")
+     path = line[start:].strip()
+     filename = path.split("/")[-1]
+     assert filename.rpartition(".so")[0].startswith(
+         lib_name
+     ), f"Unexpected filename: {filename} for library {lib_name}"
+     return path
+
+
+ class CudaRTLibrary:
+     exported_functions = [
+         # cudaError_t cudaSetDevice ( int device )
+         Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
+         # cudaError_t cudaDeviceSynchronize ( void )
+         Function("cudaDeviceSynchronize", cudaError_t, []),
+         # cudaError_t cudaDeviceReset ( void )
+         Function("cudaDeviceReset", cudaError_t, []),
+         # const char* cudaGetErrorString ( cudaError_t error )
+         Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
+         # cudaError_t cudaMalloc ( void** devPtr, size_t size )
+         Function(
+             "cudaMalloc",
+             cudaError_t,
+             [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t],
+         ),
+         # cudaError_t cudaFree ( void* devPtr )
+         Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
+         # cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
+         Function(
+             "cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
+         ),
+         # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind )  # noqa
+         Function(
+             "cudaMemcpy",
+             cudaError_t,
+             [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind],
+         ),
+         # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr )  # noqa
+         Function(
+             "cudaIpcGetMemHandle",
+             cudaError_t,
+             [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p],
+         ),
+         # cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags )  # noqa
+         Function(
+             "cudaIpcOpenMemHandle",
+             cudaError_t,
+             [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint],
+         ),
+     ]
+
+     # class attribute to store the mapping from the path to the library
+     # to avoid loading the same library multiple times
+     path_to_library_cache: Dict[str, Any] = {}
+
+     # class attribute to store the mapping from library path
+     # to the corresponding dictionary
+     path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+
+     def __init__(self, so_file: Optional[str] = None):
+         if so_file is None:
+             so_file = find_loaded_library("libcudart")
+             assert so_file is not None, "libcudart is not loaded in the current process"
+         if so_file not in CudaRTLibrary.path_to_library_cache:
+             lib = ctypes.CDLL(so_file)
+             CudaRTLibrary.path_to_library_cache[so_file] = lib
+         self.lib = CudaRTLibrary.path_to_library_cache[so_file]
+
+         if so_file not in CudaRTLibrary.path_to_dict_mapping:
+             _funcs = {}
+             for func in CudaRTLibrary.exported_functions:
+                 f = getattr(self.lib, func.name)
+                 f.restype = func.restype
+                 f.argtypes = func.argtypes
+                 _funcs[func.name] = f
+             CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
+         self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
+
+     def CUDART_CHECK(self, result: cudaError_t) -> None:
+         if result != 0:
+             error_str = self.cudaGetErrorString(result)
+             raise RuntimeError(f"CUDART error: {error_str}")
+
+     def cudaGetErrorString(self, error: cudaError_t) -> str:
+         return self.funcs["cudaGetErrorString"](error).decode("utf-8")
+
+     def cudaSetDevice(self, device: int) -> None:
+         self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
+
+     def cudaDeviceSynchronize(self) -> None:
+         self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
+
+     def cudaDeviceReset(self) -> None:
+         self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
+
+     def cudaMalloc(self, size: int) -> ctypes.c_void_p:
+         devPtr = ctypes.c_void_p()
+         self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
+         return devPtr
+
+     def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
+         self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
+
+     def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None:
+         self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
+
+     def cudaMemcpy(
+         self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int
+     ) -> None:
+         cudaMemcpyDefault = 4
+         kind = cudaMemcpyDefault
+         self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
+
+     def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
+         handle = cudaIpcMemHandle_t()
+         self.CUDART_CHECK(
+             self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr)
+         )
+         return handle
+
+     def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
+         cudaIpcMemLazyEnablePeerAccess = 1
+         devPtr = ctypes.c_void_p()
+         self.CUDART_CHECK(
+             self.funcs["cudaIpcOpenMemHandle"](
+                 ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess
+             )
+         )
+         return devPtr
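
For orientation, a minimal usage sketch of the new cudart wrapper (a hypothetical snippet, not part of the diff): it assumes a CUDA-capable machine where importing torch has already loaded libcudart, and it only calls methods defined in the file above.

import ctypes

import torch  # noqa: importing torch loads libcudart into this process

from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

cudart = CudaRTLibrary()  # locates libcudart via /proc/self/maps
cudart.cudaSetDevice(0)

nbytes = 1024
dev_ptr = cudart.cudaMalloc(nbytes)      # raw device allocation
cudart.cudaMemset(dev_ptr, 0, nbytes)    # zero the device buffer

host_buf = (ctypes.c_byte * nbytes)()    # plain pageable host buffer
# the wrapper hardcodes cudaMemcpyDefault, so the runtime infers the copy direction
cudart.cudaMemcpy(ctypes.cast(host_buf, ctypes.c_void_p), dev_ptr, nbytes)

# 128-byte handle that another process on the same node could open with
# cudaIpcOpenMemHandle; this is what create_shared_buffer() in
# custom_all_reduce.py (next hunk) relies on.
handle = cudart.cudaIpcGetMemHandle(dev_ptr)

cudart.cudaDeviceSynchronize()
cudart.cudaFree(dev_ptr)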
sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -0,0 +1,352 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce.py
+ import ctypes
+ import logging
+ import os
+ from contextlib import contextmanager
+ from functools import wraps
+ from typing import Callable, List, Optional, TypeVar, Union
+
+ import pynvml
+ import torch
+ import torch.distributed as dist
+ from torch.distributed import ProcessGroup
+ from typing_extensions import ParamSpec
+
+ from sglang.srt import _custom_ops as ops
+ from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
+ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
+     gpu_p2p_access_check,
+ )
+ from sglang.srt.distributed.parallel_state import in_the_same_node_as
+ from sglang.srt.utils import cuda_device_count_stateless, is_cuda
+
+ try:
+     ops.meta_size()
+     custom_ar = True
+ except Exception:
+     # For AMD GPUs and CPUs
+     custom_ar = False
+
+ logger = logging.getLogger(__name__)
+
+
+ _P = ParamSpec("_P")
+ _R = TypeVar("_R")
+
+
+ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
+     @wraps(fn)
+     def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
+         pynvml.nvmlInit()
+         try:
+             return fn(*args, **kwargs)
+         finally:
+             pynvml.nvmlShutdown()
+
+     return wrapper
+
+
+ @with_nvml_context
+ def is_full_nvlink(physical_device_ids: List[int]) -> bool:
+     """
+     query if the set of gpus is fully connected by nvlink (1 hop)
+     """
+     handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
+     for i, handle in enumerate(handles):
+         for j, peer_handle in enumerate(handles):
+             if i < j:
+                 try:
+                     p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                         handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
+                     )
+                     if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                         return False
+                 except pynvml.NVMLError:
+                     logger.exception(
+                         "NVLink detection failed. This is normal if your"
+                         " machine has no NVLink equipped."
+                     )
+                     return False
+     return True
+
+
+ def _can_p2p(rank: int, world_size: int) -> bool:
+     # SGLANG_SKIP_P2P_CHECK can be set to "1" to skip the P2P access test
+     SGLANG_SKIP_P2P_CHECK = os.getenv("SGLANG_SKIP_P2P_CHECK", "0") == "1"
+     for i in range(world_size):
+         if i == rank:
+             continue
+         if SGLANG_SKIP_P2P_CHECK:
+             logger.info("Skipping P2P check and trusting the driver's P2P report.")
+             return torch.cuda.can_device_access_peer(rank, i)
+         if not gpu_p2p_access_check(rank, i):
+             return False
+     return True
+
+
+ def is_weak_contiguous(inp: torch.Tensor):
+     return inp.is_contiguous() or (
+         inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
+         == inp.numel() * inp.element_size()
+     )
+
+
+ class CustomAllreduce:
+
+     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
+
+     # max_size: max supported allreduce size
+     def __init__(
+         self,
+         group: ProcessGroup,
+         device: Union[int, str, torch.device],
+         max_size=8192 * 1024,
+     ) -> None:
+         """
+         Args:
+             group: the process group to work on. If None, it will use the
+                 default process group.
+             device: the device to bind the CustomAllreduce to. If None,
+                 it will be bound to f"cuda:{local_rank}".
+         It is the caller's responsibility to make sure each communicator
+         is bound to a unique device, and all communicators in this group
+         are in the same node.
+         """
+         self._IS_CAPTURING = False
+         self.disabled = True
+
+         if not custom_ar:
+             # disable because of missing custom allreduce library
+             # e.g. in a non-cuda environment
+             return
+
+         self.group = group
+
+         assert (
+             dist.get_backend(group) != dist.Backend.NCCL
+         ), "CustomAllreduce should be attached to a non-NCCL group."
+
+         if not all(in_the_same_node_as(group, source_rank=0)):
+             # No need to initialize custom allreduce for multi-node case.
+             logger.warning(
+                 "Custom allreduce is disabled because this process group"
+                 " spans across nodes."
+             )
+             return
+
+         rank = dist.get_rank(group=self.group)
+         world_size = dist.get_world_size(group=self.group)
+         if world_size == 1:
+             # No need to initialize custom allreduce for single GPU case.
+             return
+
+         if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
+             logger.warning(
+                 "Custom allreduce is disabled due to an unsupported world"
+                 " size: %d. Supported world sizes: %s. To silence this "
+                 "warning, specify disable_custom_all_reduce=True explicitly.",
+                 world_size,
+                 str(CustomAllreduce._SUPPORTED_WORLD_SIZES),
+             )
+             return
+
+         if isinstance(device, int):
+             device = torch.device(f"cuda:{device}")
+         elif isinstance(device, str):
+             device = torch.device(device)
+         # now `device` is a `torch.device` object
+         assert isinstance(device, torch.device)
+         self.device = device
+
+         cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+         if cuda_visible_devices:
+             device_ids = list(map(int, cuda_visible_devices.split(",")))
+         else:
+             device_ids = list(range(cuda_device_count_stateless()))
+
+         physical_device_id = device_ids[device.index]
+         tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
+         gather_list = [
+             torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)
+         ]
+         dist.all_gather(gather_list, tensor, group=self.group)
+         physical_device_ids = [t.item() for t in gather_list]
+
+         # test nvlink first, this will filter out most of the cases
+         # where custom allreduce is not supported
+         # this checks hardware and driver support for NVLink
+         assert is_cuda()
+
+         full_nvlink = is_full_nvlink(physical_device_ids)
+         if world_size > 2 and not full_nvlink:
+             logger.warning(
+                 "Custom allreduce is disabled because it's not supported on"
+                 " more than two PCIe-only GPUs. To silence this warning, "
+                 "specify disable_custom_all_reduce=True explicitly."
+             )
+             return
+         # test P2P capability, this checks software/cudaruntime support
+         # this is expensive to compute at the first time
+         # then we cache the result
+         if not _can_p2p(rank, world_size):
+             logger.warning(
+                 "Custom allreduce is disabled because your platform lacks "
+                 "GPU P2P capability or P2P test failed. To silence this "
+                 "warning, specify disable_custom_all_reduce=True explicitly."
+             )
+             return
+
+         self.disabled = False
+         # Buffer memory is owned by this Python class and passed to C++.
+         # Meta data consists of two parts: meta data for synchronization and a
+         # temporary buffer for storing intermediate allreduce results.
+         self.meta_ptrs = self.create_shared_buffer(
+             ops.meta_size() + max_size, group=group
+         )
+         # This is a pre-registered IPC buffer. In eager mode, input tensors
+         # are first copied into this buffer before allreduce is performed
+         self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+         # This is a buffer for storing the tuples of pointers pointing to
+         # IPC buffers from all ranks. Each registered tuple has size of
+         # 8*world_size bytes where world_size is at most 8. Allocating 8MB
+         # is enough for 131072 such tuples. The largest model I've seen only
+         # needs less than 10000 registered tuples.
+         self.rank_data = torch.empty(
+             8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+         )
+         self.max_size = max_size
+         self.rank = rank
+         self.world_size = world_size
+         self.full_nvlink = full_nvlink
+         self._ptr = ops.init_custom_ar(
+             self.meta_ptrs, self.rank_data, rank, self.full_nvlink
+         )
+         ops.register_buffer(self._ptr, self.buffer_ptrs)
+
+     @staticmethod
+     def create_shared_buffer(
+         size_in_bytes: int, group: Optional[ProcessGroup] = None
+     ) -> List[int]:
+         """
+         Creates a shared buffer and returns a list of pointers
+         representing the buffer on all processes in the group.
+         """
+         lib = CudaRTLibrary()
+         pointer = lib.cudaMalloc(size_in_bytes)
+         handle = lib.cudaIpcGetMemHandle(pointer)
+         world_size = dist.get_world_size(group=group)
+         rank = dist.get_rank(group=group)
+         handles = [None] * world_size
+         dist.all_gather_object(handles, handle, group=group)
+
+         pointers: List[int] = []
+         for i, h in enumerate(handles):
+             if i == rank:
+                 pointers.append(pointer.value)  # type: ignore
+             else:
+                 pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # type: ignore
+
+         return pointers
+
+     @staticmethod
+     def free_shared_buffer(
+         pointers: List[int], group: Optional[ProcessGroup] = None
+     ) -> None:
+         rank = dist.get_rank(group=group)
+         lib = CudaRTLibrary()
+         lib.cudaFree(ctypes.c_void_p(pointers[rank]))
+
+     @contextmanager
+     def capture(self):
+         """
+         The main responsibility of this context manager is the
+         `register_graph_buffers` call at the end of the context.
+         It records all the buffer addresses used in the CUDA graph.
+         """
+         try:
+             self._IS_CAPTURING = True
+             yield
+         finally:
+             self._IS_CAPTURING = False
+             if not self.disabled:
+                 self.register_graph_buffers()
+
+     def register_graph_buffers(self):
+         handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
+         logger.info("Registering %d cuda graph addresses", len(offset))
+         # We cannot directly use `dist.all_gather_object` here
+         # because it is incompatible with `gloo` backend under inference mode.
+         # see https://github.com/pytorch/pytorch/issues/126032 for details.
+         all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
+         all_data[self.rank] = [handle, offset]
+         ranks = sorted(dist.get_process_group_ranks(group=self.group))
+         for i, rank in enumerate(ranks):
+             dist.broadcast_object_list(
+                 all_data[i], src=rank, group=self.group, device="cpu"
+             )
+         # Unpack list of tuples to tuple of lists.
+         handles = [d[0] for d in all_data]  # type: ignore
+         offsets = [d[1] for d in all_data]  # type: ignore
+         ops.register_graph_buffers(self._ptr, handles, offsets)
+
+     def should_custom_ar(self, inp: torch.Tensor):
+         if self.disabled:
+             return False
+         inp_size = inp.numel() * inp.element_size()
+         # custom allreduce requires input byte size to be multiples of 16
+         if inp_size % 16 != 0:
+             return False
+         if not is_weak_contiguous(inp):
+             return False
+         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
+         # little performance improvement over NCCL.
+         if self.world_size == 2 or self.full_nvlink:
+             return inp_size < self.max_size
+         return False
+
+     def all_reduce(
+         self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False
+     ):
+         """Performs an out-of-place all reduce.
+
+         If registered is True, this assumes inp's pointer is already
+         IPC-registered. Otherwise, inp is first copied into a pre-registered
+         buffer.
+         """
+         if out is None:
+             out = torch.empty_like(inp)
+         if registered:
+             ops.all_reduce(self._ptr, inp, out, 0, 0)
+         else:
+             ops.all_reduce(
+                 self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
+             )
+         return out
+
+     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
+         """The main allreduce API that provides support for cuda graph."""
+         # When custom allreduce is disabled, this will be None.
+         if self.disabled or not self.should_custom_ar(input):
+             return None
+         if self._IS_CAPTURING:
+             if torch.cuda.is_current_stream_capturing():
+                 return self.all_reduce(input, registered=True)
+             else:
+                 # If warm up, mimic the allocation pattern since custom
+                 # allreduce is out-of-place.
+                 return torch.empty_like(input)
+         else:
+             # Note: outside of cuda graph context, custom allreduce incurs a
+             # cost of cudaMemcpy, which should be small (<=1% of overall
+             # latency) compared to the performance gain of using custom kernels
+             return self.all_reduce(input, registered=False)
+
+     def close(self):
+         if not self.disabled and self._ptr:
+             ops.dispose(self._ptr)
+             self._ptr = 0
+             self.free_shared_buffer(self.meta_ptrs)
+             self.free_shared_buffer(self.buffer_ptrs)
+
+     def __del__(self):
+         self.close()
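
The class above implements only the IPC-based allreduce itself; deciding when to fall back to NCCL is left to the caller (in sglang that wiring lives in sglang/srt/distributed/parallel_state.py, which is not shown in this hunk). A hypothetical caller-side sketch, assuming torch.distributed is already initialized and a separate gloo group is used for coordination, since the constructor rejects NCCL-backed groups:

import torch
import torch.distributed as dist

from sglang.srt.distributed.device_communicators.custom_all_reduce import (
    CustomAllreduce,
)

# A gloo group over the same ranks as the NCCL world; CustomAllreduce asserts
# that its group is not NCCL-backed because it exchanges IPC handles and other
# Python objects over it.
cpu_group = dist.new_group(backend="gloo")
ca = CustomAllreduce(group=cpu_group, device=torch.cuda.current_device())


def all_reduce_with_fallback(t: torch.Tensor) -> torch.Tensor:
    # custom_all_reduce() returns None when it declines the input (disabled,
    # byte size not a multiple of 16, >= max_size, unsupported topology), so
    # the caller falls back to the regular NCCL all_reduce.
    out = ca.custom_all_reduce(t)
    if out is not None:
        return out
    dist.all_reduce(t)
    return t


# When capturing CUDA graphs, wrap the capture in `with ca.capture():` so the
# graph's buffer addresses are registered via register_graph_buffers().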