sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +18 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +76 -20
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +267 -170
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +245 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py (new file)
@@ -0,0 +1,291 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce_utils.py
+import ctypes
+import json
+import logging
+import os
+import pickle
+import subprocess
+import sys
+import tempfile
+from functools import lru_cache
+from itertools import product
+from typing import Dict, List, Optional, Sequence
+
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
+from sglang.srt.utils import cuda_device_count_stateless
+
+logger = logging.getLogger(__name__)
+
+
+def update_environment_variables(envs: Dict[str, str]):
+    for k, v in envs.items():
+        if k in os.environ and os.environ[k] != v:
+            logger.warning(
+                "Overwriting environment variable %s " "from '%s' to '%s'",
+                k,
+                os.environ[k],
+                v,
+            )
+        os.environ[k] = v
+
+
+def producer(
+    batch_src: Sequence[int],
+    producer_queue,
+    consumer_queue,
+    result_queue,
+    cuda_visible_devices: Optional[str] = None,
+):
+    if cuda_visible_devices is not None:
+        update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
+
+    lib = CudaRTLibrary()
+    for i in batch_src:
+        lib.cudaSetDevice(i)
+        pointer = lib.cudaMalloc(1024)
+        lib.cudaMemset(pointer, 1, 1024)
+        lib.cudaDeviceSynchronize()
+        handle = lib.cudaIpcGetMemHandle(pointer)
+        producer_queue.put(handle)
+        open_success = consumer_queue.get()
+        if open_success:
+            # use two queues to simulate barrier
+            producer_queue.put(0)
+            consumer_queue.get()
+            # check if the memory is modified
+            host_data = (ctypes.c_char * 1024)()
+            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
+            for i in range(1024):
+                if ord(host_data[i]) != 2:
+                    open_success = False
+                    break
+        result_queue.put(open_success)
+        lib.cudaDeviceReset()
+
+
+def consumer(
+    batch_tgt: Sequence[int],
+    producer_queue,
+    consumer_queue,
+    result_queue,
+    cuda_visible_devices: Optional[str] = None,
+):
+    if cuda_visible_devices is not None:
+        update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
+
+    lib = CudaRTLibrary()
+    for j in batch_tgt:
+        lib.cudaSetDevice(j)
+        handle = producer_queue.get()
+        open_success = False
+        try:
+            pointer = lib.cudaIpcOpenMemHandle(handle)  # type: ignore
+            open_success = True
+        except RuntimeError:
+            # cannot error out here, because the producer process
+            # is still waiting for the response.
+            pass
+        consumer_queue.put(open_success)
+        if open_success:
+            # modify the memory
+            lib.cudaMemset(pointer, 2, 1024)
+            lib.cudaDeviceSynchronize()
+            # use two queues to simulate barrier
+            producer_queue.get()
+            consumer_queue.put(0)
+            # check if the memory is modified
+            host_data = (ctypes.c_char * 1024)()
+            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
+            for i in range(1024):
+                if ord(host_data[i]) != 2:
+                    open_success = False
+                    break
+        result_queue.put(open_success)
+        lib.cudaDeviceReset()
+
+
+def can_actually_p2p(
+    batch_src: Sequence[int],
+    batch_tgt: Sequence[int],
+) -> Sequence[bool]:
+    """
+    Usually, checking if P2P access is enabled can be done by
+    `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
+    the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
+    returns `True` even if P2P access is not actually possible.
+    See https://github.com/vllm-project/vllm/issues/2728 and
+    https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
+    Therefore, we have to perform a real P2P access to check if it is actually
+    possible.
+
+    Note on p2p and cuda IPC:
+    Usually, one process uses one GPU:
+    GPU src --> cuda context src --> tensor src --> process src
+
+    We need to combine p2p and cuda IPC, so that:
+    GPU src --> cuda context src --> tensor src --> process src
+                                     |shared|
+    GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
+    That is to say, process src creates a tensor in GPU src, passes IPC handle to
+    process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
+    tensor in process tgt will be reflected in the tensor in process src, because
+    they are the same memory segment.
+    It is important to note that process tgt accesses the tensor in GPU tgt, not
+    GPU src. That's why we need p2p access.
+
+    The most time-consuming part is the process creation. To avoid creating
+    processes for every pair of GPUs, we use batched testing. We create two
+    processes for testing all pairs of GPUs in batch. The trick is to reset
+    the device after each test (which is not available in PyTorch).
+    """  # noqa
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+    # pass the CUDA_VISIBLE_DEVICES to the child process
+    # to make sure they see the same set of GPUs
+
+    # make sure the processes are spawned
+    smp = mp.get_context("spawn")
+    producer_queue = smp.Queue()
+    consumer_queue = smp.Queue()
+    result_queue = smp.Queue()
+    p_src = smp.Process(
+        target=producer,
+        args=(
+            batch_src,
+            producer_queue,
+            consumer_queue,
+            result_queue,
+            cuda_visible_devices,
+        ),
+    )
+    p_tgt = smp.Process(
+        target=consumer,
+        args=(
+            batch_tgt,
+            producer_queue,
+            consumer_queue,
+            result_queue,
+            cuda_visible_devices,
+        ),
+    )
+    p_src.start()
+    p_tgt.start()
+    p_src.join()
+    p_tgt.join()
+    assert p_src.exitcode == 0 and p_tgt.exitcode == 0
+    result: List[bool] = []
+    for src, tgt in zip(batch_src, batch_tgt):
+        a = result_queue.get()
+        b = result_queue.get()
+        if a != b:
+            logger.warning(
+                "Two processes do not agree on the P2P access"
+                " status on %d -> %d, treat as disabled.",
+                src,
+                tgt,
+            )
+            result.append(False)
+        else:
+            result.append(a)
+    return result
+
+
+# why do we need this cache?
+# we are testing peer-to-peer (p2p) access between GPUs,across processes.
+# if we test it every time, it will be very slow, because we need to create
+# N * N * 2 processes, where N is the world size. This is very slow.
+# to reduce the time, we use a cache file to store the p2p access status.
+# the cache file is generated by the master process if it does not exist.
+# then all the processes can read the cache file to check the p2p access status.
+# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
+# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
+# e.g. used by different vllm engines. The device id in the cache file is a
+# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
+# of visible devices in the vllm engine.
+_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
+
+
+def gpu_p2p_access_check(src: int, tgt: int) -> bool:
+    """Check if GPU src can access GPU tgt."""
+
+    # if the cache variable is already calculated,
+    # read from the cache instead of checking it again
+    global _gpu_p2p_access_cache
+    if _gpu_p2p_access_cache is not None:
+        return _gpu_p2p_access_cache[f"{src}->{tgt}"]
+
+    is_distributed = dist.is_initialized()
+
+    num_dev = cuda_device_count_stateless()
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+    if cuda_visible_devices is None:
+        cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
+
+    # VLLM_CACHE_ROOT -> SGLANG_CACHE_ROOT
+    # "~/.cache/vllm" -> "~/.cache/sglang"
+    SGLANG_CACHE_ROOT = os.path.expanduser("~/.cache/sglang")
+    path = os.path.join(
+        SGLANG_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
+    )
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    from sglang.srt.distributed.parallel_state import get_world_group
+
+    if (not is_distributed or get_world_group().local_rank == 0) and (
+        not os.path.exists(path)
+    ):
+        # only the local master process (with local_rank == 0) can
+        # enter this block to calculate the cache
+        logger.info("generating GPU P2P access cache in %s", path)
+        cache: Dict[str, bool] = {}
+        ids = list(range(num_dev))
+        # batch of all pairs of GPUs
+        batch_src, batch_tgt = zip(*list(product(ids, ids)))
+        # NOTE: we use `subprocess` rather than `multiprocessing` here
+        # because the caller might not have `if __name__ == "__main__":`,
+        # in that case we cannot use spawn method in multiprocessing.
+        # However, `can_actually_p2p` requires spawn method.
+        # The fix is, we use `subprocess` to call the function,
+        # where we have `if __name__ == "__main__":` in this file.
+
+        # use a temporary file to store the result
+        # we don't use the output of the subprocess directly,
+        # because the subprocess might produce logging output
+        with tempfile.NamedTemporaryFile() as output_file:
+            input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))
+            returned = subprocess.run(
+                [sys.executable, __file__], input=input_bytes, capture_output=True
+            )
+            # check if the subprocess is successful
+            try:
+                returned.check_returncode()
+            except Exception as e:
+                # wrap raised exception to provide more information
+                raise RuntimeError(
+                    f"Error happened when batch testing "
+                    f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
+                    f"{returned.stderr.decode()}"
+                ) from e
+            with open(output_file.name, "rb") as f:
+                result = pickle.load(f)
+        for _i, _j, r in zip(batch_src, batch_tgt, result):
+            cache[f"{_i}->{_j}"] = r
+        with open(path, "w") as f:
+            json.dump(cache, f, indent=4)
+    if is_distributed:
+        get_world_group().barrier()
+    logger.info("reading GPU P2P access cache from %s", path)
+    with open(path) as f:
+        cache = json.load(f)
+    _gpu_p2p_access_cache = cache
+    return _gpu_p2p_access_cache[f"{src}->{tgt}"]
+
+
+__all__ = ["gpu_p2p_access_check"]
+
+if __name__ == "__main__":
+    batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
+    result = can_actually_p2p(batch_src, batch_tgt)
+    with open(output_file, "wb") as f:
+        f.write(pickle.dumps(result))
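
Note: the new custom_all_reduce_utils.py above verifies peer-to-peer access by doing a real IPC read/write between GPU pairs (rather than trusting torch.cuda.can_device_access_peer) and caches the verdict under ~/.cache/sglang. A minimal usage sketch, assuming a single machine with at least two visible CUDA GPUs; the device indices below are illustrative and not part of the diff:

from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
    gpu_p2p_access_check,
)

# The first call spawns the batched producer/consumer test in a subprocess and
# writes gpu_p2p_access_cache_for_<CUDA_VISIBLE_DEVICES>.json under ~/.cache/sglang;
# later calls are answered from the in-memory cache.
if gpu_p2p_access_check(src=0, tgt=1):
    print("GPU 0 -> GPU 1 P2P verified")
else:
    print("GPU 0 -> GPU 1 P2P unusable; custom all-reduce would be disabled for this pair")
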
sglang/srt/distributed/device_communicators/hpu_communicator.py (new file)
@@ -0,0 +1,48 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/hpu_communicator.py
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from sglang.srt.utils import is_hpu
+
+if is_hpu():
+    import habana_frameworks.torch as htorch  # noqa: F401
+
+
+class HpuCommunicator:
+
+    def __init__(self, group: ProcessGroup):
+        if not is_hpu():
+            self.disabled = True
+            return
+        self.disabled = False
+        self.group = group
+        self.world_size = dist.get_world_size(self.group)
+
+    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+        # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
+        # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
+        # (which is required for tensor parallel HPUGraph inference)
+        htorch.core.mark_step()
+        dist.all_reduce(x, group=self.group)
+        return x
+
+    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        world_size = self.world_size
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += x.dim()
+        input_size = x.size()
+        # Allocate output tensor.
+        output_tensor = torch.empty(
+            (world_size,) + input_size, dtype=x.dtype, device=x.device
+        )
+        # All-gather.
+        htorch.core.mark_step()
+        dist.all_gather_into_tensor(output_tensor, x, group=self.group)
+        # Reshape
+        output_tensor = output_tensor.movedim(0, dim)
+        output_tensor = output_tensor.reshape(
+            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
+        )
+        return output_tensor
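
Note: HpuCommunicator.all_gather above gathers the per-rank tensors into a new leading dimension and then folds that dimension into the requested dim via movedim plus reshape. A CPU-only sketch of just that layout step, with a made-up world_size and tensor shape, so no HPU or process group is required:

import torch

world_size, dim = 4, 1
x = torch.arange(6).reshape(2, 3)                # per-rank tensor, shape (2, 3)
gathered = torch.stack([x] * world_size, dim=0)  # what all_gather_into_tensor fills: (4, 2, 3)

input_size = x.size()
out = gathered.movedim(0, dim).reshape(
    input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
)
print(out.shape)  # torch.Size([2, 12]): ranks concatenated along dim=1

The stacked copies are identical here, but the shape arithmetic matches what all_gather produces on real HPU ranks.
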
sglang/srt/distributed/device_communicators/pynccl.py (new file)
@@ -0,0 +1,204 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
+import logging
+from contextlib import contextmanager
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup, ReduceOp
+
+from sglang.srt.distributed.device_communicators.pynccl_wrapper import (
+    NCCLLibrary,
+    buffer_type,
+    cudaStream_t,
+    ncclComm_t,
+    ncclDataTypeEnum,
+    ncclRedOpTypeEnum,
+    ncclUniqueId,
+)
+from sglang.srt.distributed.utils import StatelessProcessGroup
+
+logger = logging.getLogger(__name__)
+
+
+class PyNcclCommunicator:
+
+    def __init__(
+        self,
+        group: Union[ProcessGroup, StatelessProcessGroup],
+        device: Union[int, str, torch.device],
+        library_path: Optional[str] = None,
+    ):
+        """
+        Args:
+            group: the process group to work on. If None, it will use the
+                default process group.
+            device: the device to bind the PyNcclCommunicator to. If None,
+                it will be bind to f"cuda:{local_rank}".
+            library_path: the path to the NCCL library. If None, it will
+                use the default library path.
+        It is the caller's responsibility to make sure each communicator
+        is bind to a unique device.
+        """
+        if not isinstance(group, StatelessProcessGroup):
+            assert dist.is_initialized()
+            assert (
+                dist.get_backend(group) != dist.Backend.NCCL
+            ), "PyNcclCommunicator should be attached to a non-NCCL group."
+            # note: this rank is the rank in the group
+            self.rank = dist.get_rank(group)
+            self.world_size = dist.get_world_size(group)
+        else:
+            self.rank = group.rank
+            self.world_size = group.world_size
+
+        self.group = group
+
+        # if world_size == 1, no need to create communicator
+        if self.world_size == 1:
+            self.available = False
+            self.disabled = True
+            self.stream = None
+            return
+        try:
+            self.nccl = NCCLLibrary(library_path)
+        except Exception:
+            # disable because of missing NCCL library
+            # e.g. in a non-GPU environment
+            self.available = False
+            self.disabled = True
+            self.stream = None
+            return
+
+        self.available = True
+        self.disabled = False
+
+        logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
+
+        if self.rank == 0:
+            # get the unique id from NCCL
+            self.unique_id = self.nccl.ncclGetUniqueId()
+        else:
+            # construct an empty unique id
+            self.unique_id = ncclUniqueId()
+
+        if not isinstance(group, StatelessProcessGroup):
+            tensor = torch.ByteTensor(list(self.unique_id.internal))
+            ranks = dist.get_process_group_ranks(group)
+            # arg `src` in `broadcast` is the global rank
+            dist.broadcast(tensor, src=ranks[0], group=group)
+            byte_list = tensor.tolist()
+            for i, byte in enumerate(byte_list):
+                self.unique_id.internal[i] = byte
+        else:
+            self.unique_id = group.broadcast_obj(self.unique_id, src=0)
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        # now `device` is a `torch.device` object
+        assert isinstance(device, torch.device)
+        self.device = device
+        # nccl communicator and stream will use this device
+        # `torch.cuda.device` is a context manager that changes the
+        # current cuda device to the specified one
+        with torch.cuda.device(device):
+            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
+                self.world_size, self.unique_id, self.rank
+            )
+            self.stream = torch.cuda.Stream()
+
+            # A small all_reduce for warmup.
+            data = torch.zeros(1, device=device)
+            self.all_reduce(data)
+            self.stream.synchronize()
+            del data
+
+        # by default it is disabled, e.g. in profiling models and prefill phase.
+        # to use it, use under `with obj.change_state(enable=True)`, usually
+        # when we are using CUDA graph.
+        self.disabled = True
+
+    def all_reduce(
+        self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}"
+        )
+        if stream is None:
+            stream = self.stream
+        self.nccl.ncclAllReduce(
+            buffer_type(tensor.data_ptr()),
+            buffer_type(tensor.data_ptr()),
+            tensor.numel(),
+            ncclDataTypeEnum.from_torch(tensor.dtype),
+            ncclRedOpTypeEnum.from_torch(op),
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
+
+    def send(self, tensor: torch.Tensor, dst: int, stream=None):
+        if self.disabled:
+            return
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}"
+        )
+        if stream is None:
+            stream = self.stream
+        self.nccl.ncclSend(
+            buffer_type(tensor.data_ptr()),
+            tensor.numel(),
+            ncclDataTypeEnum.from_torch(tensor.dtype),
+            dst,
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
+
+    def recv(self, tensor: torch.Tensor, src: int, stream=None):
+        if self.disabled:
+            return
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}"
+        )
+        if stream is None:
+            stream = self.stream
+        self.nccl.ncclRecv(
+            buffer_type(tensor.data_ptr()),
+            tensor.numel(),
+            ncclDataTypeEnum.from_torch(tensor.dtype),
+            src,
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
+
+    @contextmanager
+    def change_state(
+        self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
+    ):
+        """
+        A context manager to change the state of the communicator.
+        """
+        if enable is None:
+            # guess a default value when not specified
+            enable = self.available
+
+        if stream is None:
+            stream = self.stream
+
+        old_disable = self.disabled
+        old_stream = self.stream
+
+        self.stream = stream
+        self.disabled = not enable
+        yield
+
+        self.disabled = old_disable
+        self.stream = old_stream
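
Note: the PyNcclCommunicator above starts in the disabled state and is only switched on around specific regions via change_state. A minimal driving sketch, assuming two GPUs on one node and a launch such as torchrun --nproc_per_node=2; the gloo bootstrap group and tensor values are illustrative, and inside sglang comparable wiring is presumably handled by the new parallel_state.py:

import torch
import torch.distributed as dist

from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator

# The communicator must be attached to a non-NCCL group (see the assert above),
# so a gloo group is used to broadcast the NCCL unique id.
dist.init_process_group(backend="gloo")
rank = dist.get_rank()
torch.cuda.set_device(rank)

comm = PyNcclCommunicator(dist.group.WORLD, device=rank)

x = torch.ones(4, device=f"cuda:{rank}")
# Disabled by default; enable it and run the collective on the current stream.
with comm.change_state(enable=True, stream=torch.cuda.current_stream()):
    comm.all_reduce(x)  # in-place sum across ranks
torch.cuda.synchronize()
print(rank, x)  # every element equals the world size
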