sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py
@@ -0,0 +1,291 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce_utils.py
+ import ctypes
+ import json
+ import logging
+ import os
+ import pickle
+ import subprocess
+ import sys
+ import tempfile
+ from functools import lru_cache
+ from itertools import product
+ from typing import Dict, List, Optional, Sequence
+
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+
+ from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
+ from sglang.srt.utils import cuda_device_count_stateless
+
+ logger = logging.getLogger(__name__)
+
+
+ def update_environment_variables(envs: Dict[str, str]):
+     for k, v in envs.items():
+         if k in os.environ and os.environ[k] != v:
+             logger.warning(
+                 "Overwriting environment variable %s " "from '%s' to '%s'",
+                 k,
+                 os.environ[k],
+                 v,
+             )
+         os.environ[k] = v
+
+
+ def producer(
+     batch_src: Sequence[int],
+     producer_queue,
+     consumer_queue,
+     result_queue,
+     cuda_visible_devices: Optional[str] = None,
+ ):
+     if cuda_visible_devices is not None:
+         update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
+
+     lib = CudaRTLibrary()
+     for i in batch_src:
+         lib.cudaSetDevice(i)
+         pointer = lib.cudaMalloc(1024)
+         lib.cudaMemset(pointer, 1, 1024)
+         lib.cudaDeviceSynchronize()
+         handle = lib.cudaIpcGetMemHandle(pointer)
+         producer_queue.put(handle)
+         open_success = consumer_queue.get()
+         if open_success:
+             # use two queues to simulate barrier
+             producer_queue.put(0)
+             consumer_queue.get()
+             # check if the memory is modified
+             host_data = (ctypes.c_char * 1024)()
+             lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
+             for i in range(1024):
+                 if ord(host_data[i]) != 2:
+                     open_success = False
+                     break
+         result_queue.put(open_success)
+         lib.cudaDeviceReset()
+
+
+ def consumer(
+     batch_tgt: Sequence[int],
+     producer_queue,
+     consumer_queue,
+     result_queue,
+     cuda_visible_devices: Optional[str] = None,
+ ):
+     if cuda_visible_devices is not None:
+         update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
+
+     lib = CudaRTLibrary()
+     for j in batch_tgt:
+         lib.cudaSetDevice(j)
+         handle = producer_queue.get()
+         open_success = False
+         try:
+             pointer = lib.cudaIpcOpenMemHandle(handle)  # type: ignore
+             open_success = True
+         except RuntimeError:
+             # cannot error out here, because the producer process
+             # is still waiting for the response.
+             pass
+         consumer_queue.put(open_success)
+         if open_success:
+             # modify the memory
+             lib.cudaMemset(pointer, 2, 1024)
+             lib.cudaDeviceSynchronize()
+             # use two queues to simulate barrier
+             producer_queue.get()
+             consumer_queue.put(0)
+             # check if the memory is modified
+             host_data = (ctypes.c_char * 1024)()
+             lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
+             for i in range(1024):
+                 if ord(host_data[i]) != 2:
+                     open_success = False
+                     break
+         result_queue.put(open_success)
+         lib.cudaDeviceReset()
+
+
+ def can_actually_p2p(
+     batch_src: Sequence[int],
+     batch_tgt: Sequence[int],
+ ) -> Sequence[bool]:
+     """
+     Usually, checking if P2P access is enabled can be done by
+     `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
+     the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
+     returns `True` even if P2P access is not actually possible.
+     See https://github.com/vllm-project/vllm/issues/2728 and
+     https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
+     Therefore, we have to perform a real P2P access to check if it is actually
+     possible.
+
+     Note on p2p and cuda IPC:
+     Usually, one process uses one GPU:
+     GPU src --> cuda context src --> tensor src --> process src
+
+     We need to combine p2p and cuda IPC, so that:
+     GPU src --> cuda context src --> tensor src --> process src
+                                       |shared|
+     GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
+     That is to say, process src creates a tensor in GPU src, passes IPC handle to
+     process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
+     tensor in process tgt will be reflected in the tensor in process src, because
+     they are the same memory segment.
+     It is important to note that process tgt accesses the tensor in GPU tgt, not
+     GPU src. That's why we need p2p access.
+
+     The most time-consuming part is the process creation. To avoid creating
+     processes for every pair of GPUs, we use batched testing. We create two
+     processes for testing all pairs of GPUs in batch. The trick is to reset
+     the device after each test (which is not available in PyTorch).
+     """  # noqa
+     cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+     # pass the CUDA_VISIBLE_DEVICES to the child process
+     # to make sure they see the same set of GPUs
+
+     # make sure the processes are spawned
+     smp = mp.get_context("spawn")
+     producer_queue = smp.Queue()
+     consumer_queue = smp.Queue()
+     result_queue = smp.Queue()
+     p_src = smp.Process(
+         target=producer,
+         args=(
+             batch_src,
+             producer_queue,
+             consumer_queue,
+             result_queue,
+             cuda_visible_devices,
+         ),
+     )
+     p_tgt = smp.Process(
+         target=consumer,
+         args=(
+             batch_tgt,
+             producer_queue,
+             consumer_queue,
+             result_queue,
+             cuda_visible_devices,
+         ),
+     )
+     p_src.start()
+     p_tgt.start()
+     p_src.join()
+     p_tgt.join()
+     assert p_src.exitcode == 0 and p_tgt.exitcode == 0
+     result: List[bool] = []
+     for src, tgt in zip(batch_src, batch_tgt):
+         a = result_queue.get()
+         b = result_queue.get()
+         if a != b:
+             logger.warning(
+                 "Two processes do not agree on the P2P access"
+                 " status on %d -> %d, treat as disabled.",
+                 src,
+                 tgt,
+             )
+             result.append(False)
+         else:
+             result.append(a)
+     return result
+
+
+ # why do we need this cache?
+ # we are testing peer-to-peer (p2p) access between GPUs,across processes.
+ # if we test it every time, it will be very slow, because we need to create
+ # N * N * 2 processes, where N is the world size. This is very slow.
+ # to reduce the time, we use a cache file to store the p2p access status.
+ # the cache file is generated by the master process if it does not exist.
+ # then all the processes can read the cache file to check the p2p access status.
+ # Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
+ # can have different cache files for different CUDA_VISIBLE_DEVICES settings,
+ # e.g. used by different vllm engines. The device id in the cache file is a
+ # **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
+ # of visible devices in the vllm engine.
+ _gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
+
+
+ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
+     """Check if GPU src can access GPU tgt."""
+
+     # if the cache variable is already calculated,
+     # read from the cache instead of checking it again
+     global _gpu_p2p_access_cache
+     if _gpu_p2p_access_cache is not None:
+         return _gpu_p2p_access_cache[f"{src}->{tgt}"]
+
+     is_distributed = dist.is_initialized()
+
+     num_dev = cuda_device_count_stateless()
+     cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+     if cuda_visible_devices is None:
+         cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
+
+     # VLLM_CACHE_ROOT -> SGLANG_CACHE_ROOT
+     # "~/.cache/vllm" -> "~/.cache/sglang"
+     SGLANG_CACHE_ROOT = os.path.expanduser("~/.cache/sglang")
+     path = os.path.join(
+         SGLANG_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
+     )
+     os.makedirs(os.path.dirname(path), exist_ok=True)
+     from sglang.srt.distributed.parallel_state import get_world_group
+
+     if (not is_distributed or get_world_group().local_rank == 0) and (
+         not os.path.exists(path)
+     ):
+         # only the local master process (with local_rank == 0) can
+         # enter this block to calculate the cache
+         logger.info("generating GPU P2P access cache in %s", path)
+         cache: Dict[str, bool] = {}
+         ids = list(range(num_dev))
+         # batch of all pairs of GPUs
+         batch_src, batch_tgt = zip(*list(product(ids, ids)))
+         # NOTE: we use `subprocess` rather than `multiprocessing` here
+         # because the caller might not have `if __name__ == "__main__":`,
+         # in that case we cannot use spawn method in multiprocessing.
+         # However, `can_actually_p2p` requires spawn method.
+         # The fix is, we use `subprocess` to call the function,
+         # where we have `if __name__ == "__main__":` in this file.
+
+         # use a temporary file to store the result
+         # we don't use the output of the subprocess directly,
+         # because the subprocess might produce logging output
+         with tempfile.NamedTemporaryFile() as output_file:
+             input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))
+             returned = subprocess.run(
+                 [sys.executable, __file__], input=input_bytes, capture_output=True
+             )
+             # check if the subprocess is successful
+             try:
+                 returned.check_returncode()
+             except Exception as e:
+                 # wrap raised exception to provide more information
+                 raise RuntimeError(
+                     f"Error happened when batch testing "
+                     f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
+                     f"{returned.stderr.decode()}"
+                 ) from e
+             with open(output_file.name, "rb") as f:
+                 result = pickle.load(f)
+         for _i, _j, r in zip(batch_src, batch_tgt, result):
+             cache[f"{_i}->{_j}"] = r
+         with open(path, "w") as f:
+             json.dump(cache, f, indent=4)
+     if is_distributed:
+         get_world_group().barrier()
+     logger.info("reading GPU P2P access cache from %s", path)
+     with open(path) as f:
+         cache = json.load(f)
+     _gpu_p2p_access_cache = cache
+     return _gpu_p2p_access_cache[f"{src}->{tgt}"]
+
+
+ __all__ = ["gpu_p2p_access_check"]
+
+ if __name__ == "__main__":
+     batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
+     result = can_actually_p2p(batch_src, batch_tgt)
+     with open(output_file, "wb") as f:
+         f.write(pickle.dumps(result))
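The module's only public entry point is gpu_p2p_access_check(src, tgt). Below is a minimal usage sketch, not code from this release: the p2p_fully_connected helper and its world_size argument are illustrative assumptions showing how a caller might gate a custom all-reduce path on the cached P2P test results.

# Illustrative sketch only; p2p_fully_connected is a hypothetical helper.
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
    gpu_p2p_access_check,
)


def p2p_fully_connected(world_size: int) -> bool:
    # The first call on local rank 0 runs the batched subprocess test and writes
    # ~/.cache/sglang/gpu_p2p_access_cache_for_<devices>.json; subsequent calls
    # (and the other ranks) just read the cached result.
    return all(
        gpu_p2p_access_check(src, tgt)
        for src in range(world_size)
        for tgt in range(world_size)
        if src != tgt
    )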
sglang/srt/distributed/device_communicators/hpu_communicator.py
@@ -0,0 +1,48 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/hpu_communicator.py
+ import torch
+ import torch.distributed as dist
+ from torch.distributed import ProcessGroup
+
+ from sglang.srt.utils import is_hpu
+
+ if is_hpu():
+     import habana_frameworks.torch as htorch  # noqa: F401
+
+
+ class HpuCommunicator:
+
+     def __init__(self, group: ProcessGroup):
+         if not is_hpu():
+             self.disabled = True
+             return
+         self.disabled = False
+         self.group = group
+         self.world_size = dist.get_world_size(self.group)
+
+     def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+         # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
+         # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
+         # (which is required for tensor parallel HPUGraph inference)
+         htorch.core.mark_step()
+         dist.all_reduce(x, group=self.group)
+         return x
+
+     def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+         world_size = self.world_size
+         if dim < 0:
+             # Convert negative dim to positive.
+             dim += x.dim()
+         input_size = x.size()
+         # Allocate output tensor.
+         output_tensor = torch.empty(
+             (world_size,) + input_size, dtype=x.dtype, device=x.device
+         )
+         # All-gather.
+         htorch.core.mark_step()
+         dist.all_gather_into_tensor(output_tensor, x, group=self.group)
+         # Reshape
+         output_tensor = output_tensor.movedim(0, dim)
+         output_tensor = output_tensor.reshape(
+             input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
+         )
+         return output_tensor
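For context, a short sketch of how the new HpuCommunicator could be exercised inside an already-initialized torch.distributed process group on Gaudi hardware; the demo function and its tensor shapes are assumptions for illustration, not part of the diff. On non-HPU hosts the constructor sets disabled = True and the sketch returns early.

# Illustrative sketch only; assumes torch.distributed is already initialized.
import torch
import torch.distributed as dist

from sglang.srt.distributed.device_communicators.hpu_communicator import HpuCommunicator


def demo(group: dist.ProcessGroup) -> None:
    comm = HpuCommunicator(group)
    if comm.disabled:  # not running on HPU
        return
    x = torch.ones(4, 8, device="hpu")
    x = comm.all_reduce(x)                # element-wise sum across all ranks
    gathered = comm.all_gather(x, dim=0)  # shape becomes (world_size * 4, 8)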
sglang/srt/distributed/device_communicators/pynccl.py
@@ -0,0 +1,204 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
+ import logging
+ from contextlib import contextmanager
+ from typing import Optional, Union
+
+ import torch
+ import torch.distributed as dist
+ from torch.distributed import ProcessGroup, ReduceOp
+
+ from sglang.srt.distributed.device_communicators.pynccl_wrapper import (
+     NCCLLibrary,
+     buffer_type,
+     cudaStream_t,
+     ncclComm_t,
+     ncclDataTypeEnum,
+     ncclRedOpTypeEnum,
+     ncclUniqueId,
+ )
+ from sglang.srt.distributed.utils import StatelessProcessGroup
+
+ logger = logging.getLogger(__name__)
+
+
+ class PyNcclCommunicator:
+
+     def __init__(
+         self,
+         group: Union[ProcessGroup, StatelessProcessGroup],
+         device: Union[int, str, torch.device],
+         library_path: Optional[str] = None,
+     ):
+         """
+         Args:
+             group: the process group to work on. If None, it will use the
+                 default process group.
+             device: the device to bind the PyNcclCommunicator to. If None,
+                 it will be bind to f"cuda:{local_rank}".
+             library_path: the path to the NCCL library. If None, it will
+                 use the default library path.
+         It is the caller's responsibility to make sure each communicator
+         is bind to a unique device.
+         """
+         if not isinstance(group, StatelessProcessGroup):
+             assert dist.is_initialized()
+             assert (
+                 dist.get_backend(group) != dist.Backend.NCCL
+             ), "PyNcclCommunicator should be attached to a non-NCCL group."
+             # note: this rank is the rank in the group
+             self.rank = dist.get_rank(group)
+             self.world_size = dist.get_world_size(group)
+         else:
+             self.rank = group.rank
+             self.world_size = group.world_size
+
+         self.group = group
+
+         # if world_size == 1, no need to create communicator
+         if self.world_size == 1:
+             self.available = False
+             self.disabled = True
+             self.stream = None
+             return
+         try:
+             self.nccl = NCCLLibrary(library_path)
+         except Exception:
+             # disable because of missing NCCL library
+             # e.g. in a non-GPU environment
+             self.available = False
+             self.disabled = True
+             self.stream = None
+             return
+
+         self.available = True
+         self.disabled = False
+
+         logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
+
+         if self.rank == 0:
+             # get the unique id from NCCL
+             self.unique_id = self.nccl.ncclGetUniqueId()
+         else:
+             # construct an empty unique id
+             self.unique_id = ncclUniqueId()
+
+         if not isinstance(group, StatelessProcessGroup):
+             tensor = torch.ByteTensor(list(self.unique_id.internal))
+             ranks = dist.get_process_group_ranks(group)
+             # arg `src` in `broadcast` is the global rank
+             dist.broadcast(tensor, src=ranks[0], group=group)
+             byte_list = tensor.tolist()
+             for i, byte in enumerate(byte_list):
+                 self.unique_id.internal[i] = byte
+         else:
+             self.unique_id = group.broadcast_obj(self.unique_id, src=0)
+         if isinstance(device, int):
+             device = torch.device(f"cuda:{device}")
+         elif isinstance(device, str):
+             device = torch.device(device)
+         # now `device` is a `torch.device` object
+         assert isinstance(device, torch.device)
+         self.device = device
+         # nccl communicator and stream will use this device
+         # `torch.cuda.device` is a context manager that changes the
+         # current cuda device to the specified one
+         with torch.cuda.device(device):
+             self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
+                 self.world_size, self.unique_id, self.rank
+             )
+             self.stream = torch.cuda.Stream()
+
+             # A small all_reduce for warmup.
+             data = torch.zeros(1, device=device)
+             self.all_reduce(data)
+             self.stream.synchronize()
+             del data
+
+         # by default it is disabled, e.g. in profiling models and prefill phase.
+         # to use it, use under `with obj.change_state(enable=True)`, usually
+         # when we are using CUDA graph.
+         self.disabled = True
+
+     def all_reduce(
+         self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
+     ):
+         if self.disabled:
+             return
+         # nccl communicator created on a specific device
+         # will only work on tensors on the same device
+         # otherwise it will cause "illegal memory access"
+         assert tensor.device == self.device, (
+             f"this nccl communicator is created to work on {self.device}, "
+             f"but the input tensor is on {tensor.device}"
+         )
+         if stream is None:
+             stream = self.stream
+         self.nccl.ncclAllReduce(
+             buffer_type(tensor.data_ptr()),
+             buffer_type(tensor.data_ptr()),
+             tensor.numel(),
+             ncclDataTypeEnum.from_torch(tensor.dtype),
+             ncclRedOpTypeEnum.from_torch(op),
+             self.comm,
+             cudaStream_t(stream.cuda_stream),
+         )
+
+     def send(self, tensor: torch.Tensor, dst: int, stream=None):
+         if self.disabled:
+             return
+         assert tensor.device == self.device, (
+             f"this nccl communicator is created to work on {self.device}, "
+             f"but the input tensor is on {tensor.device}"
+         )
+         if stream is None:
+             stream = self.stream
+         self.nccl.ncclSend(
+             buffer_type(tensor.data_ptr()),
+             tensor.numel(),
+             ncclDataTypeEnum.from_torch(tensor.dtype),
+             dst,
+             self.comm,
+             cudaStream_t(stream.cuda_stream),
+         )
+
+     def recv(self, tensor: torch.Tensor, src: int, stream=None):
+         if self.disabled:
+             return
+         assert tensor.device == self.device, (
+             f"this nccl communicator is created to work on {self.device}, "
+             f"but the input tensor is on {tensor.device}"
+         )
+         if stream is None:
+             stream = self.stream
+         self.nccl.ncclRecv(
+             buffer_type(tensor.data_ptr()),
+             tensor.numel(),
+             ncclDataTypeEnum.from_torch(tensor.dtype),
+             src,
+             self.comm,
+             cudaStream_t(stream.cuda_stream),
+         )
+
+     @contextmanager
+     def change_state(
+         self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
+     ):
+         """
+         A context manager to change the state of the communicator.
+         """
+         if enable is None:
+             # guess a default value when not specified
+             enable = self.available
+
+         if stream is None:
+             stream = self.stream
+
+         old_disable = self.disabled
+         old_stream = self.stream
+
+         self.stream = stream
+         self.disabled = not enable
+         yield
+
+         self.disabled = old_disable
+         self.stream = old_stream
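A usage sketch for PyNcclCommunicator, assuming a non-NCCL (e.g. gloo) process group is already initialized and each rank owns one CUDA device; the demo function and its arguments are illustrative, not part of the release. It follows the pattern described in the constructor comments: the communicator starts out disabled and is only enabled inside change_state, e.g. around CUDA graph capture and replay.

# Illustrative sketch only; assumes a gloo group and one CUDA device per rank.
import torch
import torch.distributed as dist

from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator


def demo(group: dist.ProcessGroup, local_rank: int) -> None:
    comm = PyNcclCommunicator(group, device=local_rank)
    if not comm.available:  # single rank, or the NCCL library was not found
        return
    t = torch.ones(16, device=f"cuda:{local_rank}")
    comm.all_reduce(t)  # no-op: the communicator is disabled by default
    with comm.change_state(enable=True, stream=torch.cuda.current_stream()):
        comm.all_reduce(t)  # runs ncclAllReduce on the current stream
    torch.cuda.synchronize()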