sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- sglang/bench_offline_throughput.py +55 -2
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +4 -3
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/launch_server.py +3 -2
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -11
- sglang/srt/managers/detokenizer_manager.py +7 -4
- sglang/srt/managers/image_processor.py +1 -1
- sglang/srt/managers/io_struct.py +48 -12
- sglang/srt/managers/schedule_batch.py +42 -36
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +111 -46
- sglang/srt/managers/session_controller.py +0 -3
- sglang/srt/managers/tokenizer_manager.py +169 -100
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +14 -51
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +10 -12
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +391 -0
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +12 -9
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +10 -6
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +303 -204
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +253 -48
- sglang/test/test_utils.py +27 -7
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
- sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
- sglang/srt/layers/fused_moe_grok/layer.py +0 -630
- sglang-0.3.6.post2.dist-info/RECORD +0 -164
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/device_communicators/pynccl.py
@@ -0,0 +1,204 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
+import logging
+from contextlib import contextmanager
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup, ReduceOp
+
+from sglang.srt.distributed.device_communicators.pynccl_wrapper import (
+    NCCLLibrary,
+    buffer_type,
+    cudaStream_t,
+    ncclComm_t,
+    ncclDataTypeEnum,
+    ncclRedOpTypeEnum,
+    ncclUniqueId,
+)
+from sglang.srt.distributed.utils import StatelessProcessGroup
+
+logger = logging.getLogger(__name__)
+
+
+class PyNcclCommunicator:
+
+    def __init__(
+        self,
+        group: Union[ProcessGroup, StatelessProcessGroup],
+        device: Union[int, str, torch.device],
+        library_path: Optional[str] = None,
+    ):
+        """
+        Args:
+            group: the process group to work on. If None, it will use the
+                default process group.
+            device: the device to bind the PyNcclCommunicator to. If None,
+                it will be bind to f"cuda:{local_rank}".
+            library_path: the path to the NCCL library. If None, it will
+                use the default library path.
+        It is the caller's responsibility to make sure each communicator
+        is bind to a unique device.
+        """
+        if not isinstance(group, StatelessProcessGroup):
+            assert dist.is_initialized()
+            assert (
+                dist.get_backend(group) != dist.Backend.NCCL
+            ), "PyNcclCommunicator should be attached to a non-NCCL group."
+            # note: this rank is the rank in the group
+            self.rank = dist.get_rank(group)
+            self.world_size = dist.get_world_size(group)
+        else:
+            self.rank = group.rank
+            self.world_size = group.world_size
+
+        self.group = group
+
+        # if world_size == 1, no need to create communicator
+        if self.world_size == 1:
+            self.available = False
+            self.disabled = True
+            self.stream = None
+            return
+        try:
+            self.nccl = NCCLLibrary(library_path)
+        except Exception:
+            # disable because of missing NCCL library
+            # e.g. in a non-GPU environment
+            self.available = False
+            self.disabled = True
+            self.stream = None
+            return
+
+        self.available = True
+        self.disabled = False
+
+        logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
+
+        if self.rank == 0:
+            # get the unique id from NCCL
+            self.unique_id = self.nccl.ncclGetUniqueId()
+        else:
+            # construct an empty unique id
+            self.unique_id = ncclUniqueId()
+
+        if not isinstance(group, StatelessProcessGroup):
+            tensor = torch.ByteTensor(list(self.unique_id.internal))
+            ranks = dist.get_process_group_ranks(group)
+            # arg `src` in `broadcast` is the global rank
+            dist.broadcast(tensor, src=ranks[0], group=group)
+            byte_list = tensor.tolist()
+            for i, byte in enumerate(byte_list):
+                self.unique_id.internal[i] = byte
+        else:
+            self.unique_id = group.broadcast_obj(self.unique_id, src=0)
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        # now `device` is a `torch.device` object
+        assert isinstance(device, torch.device)
+        self.device = device
+        # nccl communicator and stream will use this device
+        # `torch.cuda.device` is a context manager that changes the
+        # current cuda device to the specified one
+        with torch.cuda.device(device):
+            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
+                self.world_size, self.unique_id, self.rank
+            )
+            self.stream = torch.cuda.Stream()
+
+            # A small all_reduce for warmup.
+            data = torch.zeros(1, device=device)
+            self.all_reduce(data)
+            self.stream.synchronize()
+            del data
+
+        # by default it is disabled, e.g. in profiling models and prefill phase.
+        # to use it, use under `with obj.change_state(enable=True)`, usually
+        # when we are using CUDA graph.
+        self.disabled = True
+
+    def all_reduce(
+        self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}"
+        )
+        if stream is None:
+            stream = self.stream
+        self.nccl.ncclAllReduce(
+            buffer_type(tensor.data_ptr()),
+            buffer_type(tensor.data_ptr()),
+            tensor.numel(),
+            ncclDataTypeEnum.from_torch(tensor.dtype),
+            ncclRedOpTypeEnum.from_torch(op),
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
+
+    def send(self, tensor: torch.Tensor, dst: int, stream=None):
+        if self.disabled:
+            return
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}"
+        )
+        if stream is None:
+            stream = self.stream
+        self.nccl.ncclSend(
+            buffer_type(tensor.data_ptr()),
+            tensor.numel(),
+            ncclDataTypeEnum.from_torch(tensor.dtype),
+            dst,
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
+
+    def recv(self, tensor: torch.Tensor, src: int, stream=None):
+        if self.disabled:
+            return
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}"
+        )
+        if stream is None:
+            stream = self.stream
+        self.nccl.ncclRecv(
+            buffer_type(tensor.data_ptr()),
+            tensor.numel(),
+            ncclDataTypeEnum.from_torch(tensor.dtype),
+            src,
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
+
+    @contextmanager
+    def change_state(
+        self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
+    ):
+        """
+        A context manager to change the state of the communicator.
+        """
+        if enable is None:
+            # guess a default value when not specified
+            enable = self.available
+
+        if stream is None:
+            stream = self.stream
+
+        old_disable = self.disabled
+        old_stream = self.stream
+
+        self.stream = stream
+        self.disabled = not enable
+        yield
+
+        self.disabled = old_disable
+        self.stream = old_stream
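The new `pynccl.py` gives sglang a thin NCCL communicator that is safe to use during CUDA graph capture: it is disabled by default and only performs collectives inside `change_state(enable=True)`, and `all_reduce` is in place (the same `tensor.data_ptr()` is passed as both send and receive buffer). The sketch below is illustrative only, not part of the diff; it assumes a hypothetical multi-rank launch where `torch.distributed` has already been initialized with a non-NCCL backend such as gloo, since the constructor asserts the group's backend is not NCCL.

```python
# Illustrative sketch, not part of the package: exercising PyNcclCommunicator
# under an assumed one-GPU-per-rank launch with a gloo default group.
import torch
import torch.distributed as dist

from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator

dist.init_process_group(backend="gloo")  # non-NCCL group, as the class requires
rank = dist.get_rank()

comm = PyNcclCommunicator(group=dist.group.WORLD, device=rank)  # bound to cuda:{rank}

x = torch.ones(16, device=comm.device)

# Disabled by default (e.g. during profiling or prefill); collectives are no-ops
# until the communicator is enabled, typically around CUDA graph capture.
with comm.change_state(enable=True):
    comm.all_reduce(x)         # in-place SUM across ranks on comm.stream
    comm.stream.synchronize()

assert torch.allclose(x, torch.full_like(x, dist.get_world_size()))
```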
sglang/srt/distributed/device_communicators/pynccl_wrapper.py
@@ -0,0 +1,362 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
+
+# This file is a pure Python wrapper for the NCCL library.
+# The main purpose is to use NCCL combined with CUDA graph.
+# Before writing this script, we tried the following approach:
+# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
+# often gets stuck when initializing the NCCL communicator.
+# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
+# contains many other potential cuda APIs, that are not allowed during
+# capturing the CUDA graph. For further details, please check
+# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
+#
+# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
+# doable, but we often encounter issues related with nccl versions, and need
+# to switch between different versions of NCCL. See
+# https://github.com/NVIDIA/nccl/issues/1234 for more details.
+# A C/C++ binding is not flexible enough to handle this. It requires
+# recompilation of the code every time we want to switch between different
+# versions. This current implementation, with a **pure** Python wrapper, is
+# more flexible. We can easily switch between different versions of NCCL by
+# changing the environment variable `SGLANG_NCCL_SO_PATH`, or the `so_file`
+# variable in the code.
+
+import ctypes
+import logging
+import os
+import platform
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.distributed import ReduceOp
+
+logger = logging.getLogger(__name__)
+
+
+def find_nccl_library() -> str:
+    """
+    We either use the library file specified by the `SGLANG_NCCL_SO_PATH`
+    environment variable, or we find the library file brought by PyTorch.
+    After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
+    found by `ctypes` automatically.
+    """
+
+    # so_file can be set to None in sglang
+    so_file = os.environ.get("SGLANG_NCCL_SO_PATH", None)
+
+    # manually load the nccl library
+    if so_file:
+        logger.info(
+            "Found nccl from environment variable SGLANG_NCCL_SO_PATH=%s", so_file
+        )
+    else:
+        if torch.version.cuda is not None:
+            so_file = "libnccl.so.2"
+        elif torch.version.hip is not None:
+            so_file = "librccl.so.1"
+        else:
+            raise ValueError("NCCL only supports CUDA and ROCm backends.")
+        logger.info("Found nccl from library %s", so_file)
+    return so_file
+
+
+# === export types and functions from nccl to Python ===
+# for the original nccl definition, please check
+# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
+
+ncclResult_t = ctypes.c_int
+ncclComm_t = ctypes.c_void_p
+
+
+class ncclUniqueId(ctypes.Structure):
+    _fields_ = [("internal", ctypes.c_byte * 128)]
+
+
+cudaStream_t = ctypes.c_void_p
+buffer_type = ctypes.c_void_p
+
+ncclDataType_t = ctypes.c_int
+
+
+class ncclDataTypeEnum:
+    ncclInt8 = 0
+    ncclChar = 0
+    ncclUint8 = 1
+    ncclInt32 = 2
+    ncclInt = 2
+    ncclUint32 = 3
+    ncclInt64 = 4
+    ncclUint64 = 5
+    ncclFloat16 = 6
+    ncclHalf = 6
+    ncclFloat32 = 7
+    ncclFloat = 7
+    ncclFloat64 = 8
+    ncclDouble = 8
+    ncclBfloat16 = 9
+    ncclNumTypes = 10
+
+    @classmethod
+    def from_torch(cls, dtype: torch.dtype) -> int:
+        if dtype == torch.int8:
+            return cls.ncclInt8
+        if dtype == torch.uint8:
+            return cls.ncclUint8
+        if dtype == torch.int32:
+            return cls.ncclInt32
+        if dtype == torch.int64:
+            return cls.ncclInt64
+        if dtype == torch.float16:
+            return cls.ncclFloat16
+        if dtype == torch.float32:
+            return cls.ncclFloat32
+        if dtype == torch.float64:
+            return cls.ncclFloat64
+        if dtype == torch.bfloat16:
+            return cls.ncclBfloat16
+        raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+ncclRedOp_t = ctypes.c_int
+
+
+class ncclRedOpTypeEnum:
+    ncclSum = 0
+    ncclProd = 1
+    ncclMax = 2
+    ncclMin = 3
+    ncclAvg = 4
+    ncclNumOps = 5
+
+    @classmethod
+    def from_torch(cls, op: ReduceOp) -> int:
+        if op == ReduceOp.SUM:
+            return cls.ncclSum
+        if op == ReduceOp.PRODUCT:
+            return cls.ncclProd
+        if op == ReduceOp.MAX:
+            return cls.ncclMax
+        if op == ReduceOp.MIN:
+            return cls.ncclMin
+        if op == ReduceOp.AVG:
+            return cls.ncclAvg
+        raise ValueError(f"Unsupported op: {op}")
+
+
+@dataclass
+class Function:
+    name: str
+    restype: Any
+    argtypes: List[Any]
+
+
+class NCCLLibrary:
+    exported_functions = [
+        # const char* ncclGetErrorString(ncclResult_t result)
+        Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
+        # ncclResult_t ncclGetVersion(int *version);
+        Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]),
+        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
+        Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]),
+        # ncclResult_t ncclCommInitRank(
+        # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
+        # note that ncclComm_t is a pointer type, so the first argument
+        # is a pointer to a pointer
+        Function(
+            "ncclCommInitRank",
+            ncclResult_t,
+            [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int],
+        ),
+        # ncclResult_t ncclAllReduce(
+        # const void* sendbuff, void* recvbuff, size_t count,
+        # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
+        # cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function(
+            "ncclAllReduce",
+            ncclResult_t,
+            [
+                buffer_type,
+                buffer_type,
+                ctypes.c_size_t,
+                ncclDataType_t,
+                ncclRedOp_t,
+                ncclComm_t,
+                cudaStream_t,
+            ],
+        ),
+        # ncclResult_t ncclSend(
+        # const void* sendbuff, size_t count, ncclDataType_t datatype,
+        # int dest, ncclComm_t comm, cudaStream_t stream);
+        Function(
+            "ncclSend",
+            ncclResult_t,
+            [
+                buffer_type,
+                ctypes.c_size_t,
+                ncclDataType_t,
+                ctypes.c_int,
+                ncclComm_t,
+                cudaStream_t,
+            ],
+        ),
+        # ncclResult_t ncclRecv(
+        # void* recvbuff, size_t count, ncclDataType_t datatype,
+        # int src, ncclComm_t comm, cudaStream_t stream);
+        Function(
+            "ncclRecv",
+            ncclResult_t,
+            [
+                buffer_type,
+                ctypes.c_size_t,
+                ncclDataType_t,
+                ctypes.c_int,
+                ncclComm_t,
+                cudaStream_t,
+            ],
+        ),
+        # be cautious! this is a collective call, it will block until all
+        # processes in the communicator have called this function.
+        # because Python object destruction can happen in random order,
+        # it is better not to call it at all.
+        # ncclResult_t ncclCommDestroy(ncclComm_t comm);
+        Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
+    ]
+
+    # class attribute to store the mapping from the path to the library
+    # to avoid loading the same library multiple times
+    path_to_library_cache: Dict[str, Any] = {}
+
+    # class attribute to store the mapping from library path
+    # to the corresponding dictionary
+    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+
+    def __init__(self, so_file: Optional[str] = None):
+
+        so_file = so_file or find_nccl_library()
+
+        try:
+            if so_file not in NCCLLibrary.path_to_dict_mapping:
+                lib = ctypes.CDLL(so_file)
+                NCCLLibrary.path_to_library_cache[so_file] = lib
+            self.lib = NCCLLibrary.path_to_library_cache[so_file]
+        except Exception as e:
+            logger.error(
+                "Failed to load NCCL library from %s ."
+                "It is expected if you are not running on NVIDIA/AMD GPUs."
+                "Otherwise, the nccl library might not exist, be corrupted "
+                "or it does not support the current platform %s."
+                "If you already have the library, please set the "
+                "environment variable SGLANG_NCCL_SO_PATH"
+                " to point to the correct nccl library path.",
+                so_file,
+                platform.platform(),
+            )
+            raise e
+
+        if so_file not in NCCLLibrary.path_to_dict_mapping:
+            _funcs: Dict[str, Any] = {}
+            for func in NCCLLibrary.exported_functions:
+                f = getattr(self.lib, func.name)
+                f.restype = func.restype
+                f.argtypes = func.argtypes
+                _funcs[func.name] = f
+            NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
+        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
+
+    def ncclGetErrorString(self, result: ncclResult_t) -> str:
+        return self._funcs["ncclGetErrorString"](result).decode("utf-8")
+
+    def NCCL_CHECK(self, result: ncclResult_t) -> None:
+        if result != 0:
+            error_str = self.ncclGetErrorString(result)
+            raise RuntimeError(f"NCCL error: {error_str}")
+
+    def ncclGetVersion(self) -> str:
+        version = ctypes.c_int()
+        self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
+        version_str = str(version.value)
+        # something like 21903 --> "2.19.3"
+        major = version_str[0].lstrip("0")
+        minor = version_str[1:3].lstrip("0")
+        patch = version_str[3:].lstrip("0")
+        return f"{major}.{minor}.{patch}"
+
+    def ncclGetUniqueId(self) -> ncclUniqueId:
+        unique_id = ncclUniqueId()
+        self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id)))
+        return unique_id
+
+    def ncclCommInitRank(
+        self, world_size: int, unique_id: ncclUniqueId, rank: int
+    ) -> ncclComm_t:
+        comm = ncclComm_t()
+        self.NCCL_CHECK(
+            self._funcs["ncclCommInitRank"](
+                ctypes.byref(comm), world_size, unique_id, rank
+            )
+        )
+        return comm
+
+    def ncclAllReduce(
+        self,
+        sendbuff: buffer_type,
+        recvbuff: buffer_type,
+        count: int,
+        datatype: int,
+        op: int,
+        comm: ncclComm_t,
+        stream: cudaStream_t,
+    ) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(
+            self._funcs["ncclAllReduce"](
+                sendbuff, recvbuff, count, datatype, op, comm, stream
+            )
+        )

+    def ncclSend(
+        self,
+        sendbuff: buffer_type,
+        count: int,
+        datatype: int,
+        dest: int,
+        comm: ncclComm_t,
+        stream: cudaStream_t,
+    ) -> None:
+        self.NCCL_CHECK(
+            self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream)
+        )
+
+    def ncclRecv(
+        self,
+        recvbuff: buffer_type,
+        count: int,
+        datatype: int,
+        src: int,
+        comm: ncclComm_t,
+        stream: cudaStream_t,
+    ) -> None:
+        self.NCCL_CHECK(
+            self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)
+        )
+
+    def ncclCommDestroy(self, comm: ncclComm_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
+
+
+__all__ = [
+    "NCCLLibrary",
+    "ncclDataTypeEnum",
+    "ncclRedOpTypeEnum",
+    "ncclUniqueId",
+    "ncclComm_t",
+    "cudaStream_t",
+    "buffer_type",
+]
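Since `pynccl_wrapper.py` resolves and binds the NCCL shared library at runtime via `ctypes`, one quick way to see which build sglang would pick up is to load the wrapper directly on a single rank. The snippet below is an illustrative sketch, not part of the diff; the commented-out library path and the printed version string are made-up examples.

```python
# Illustrative sketch, not part of the package: load the NCCL wrapper directly
# to check which library and version sglang would bind to.
import os

from sglang.srt.distributed.device_communicators.pynccl_wrapper import NCCLLibrary

# Optional: point at a specific build instead of the one bundled with PyTorch
# (the path below is only an example).
# os.environ["SGLANG_NCCL_SO_PATH"] = "/usr/lib/x86_64-linux-gnu/libnccl.so.2"

nccl = NCCLLibrary()  # falls back to libnccl.so.2 (CUDA) or librccl.so.1 (ROCm)
print("NCCL version:", nccl.ncclGetVersion())  # e.g. "2.20.5"

unique_id = nccl.ncclGetUniqueId()  # 128-byte handle that rank 0 broadcasts
print("unique id length:", len(unique_id.internal))  # 128
```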