sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +55 -2
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +4 -3
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/launch_server.py +3 -2
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -11
- sglang/srt/managers/detokenizer_manager.py +7 -4
- sglang/srt/managers/image_processor.py +1 -1
- sglang/srt/managers/io_struct.py +48 -12
- sglang/srt/managers/schedule_batch.py +42 -36
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +111 -46
- sglang/srt/managers/session_controller.py +0 -3
- sglang/srt/managers/tokenizer_manager.py +169 -100
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +14 -51
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +10 -12
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +391 -0
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +12 -9
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +10 -6
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +303 -204
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +253 -48
- sglang/test/test_utils.py +27 -7
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
- sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
- sglang/srt/layers/fused_moe_grok/layer.py +0 -630
- sglang-0.3.6.post2.dist-info/RECORD +0 -164
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,568 @@
|
|
1
|
+
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/shm_broadcast.py
|
2
|
+
import ipaddress
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
import pickle
|
6
|
+
import socket
|
7
|
+
import time
|
8
|
+
import warnings
|
9
|
+
from contextlib import contextmanager
|
10
|
+
from dataclasses import dataclass, field
|
11
|
+
from multiprocessing import shared_memory
|
12
|
+
from typing import List, Optional
|
13
|
+
from unittest.mock import patch
|
14
|
+
|
15
|
+
import torch
|
16
|
+
import torch.distributed as dist
|
17
|
+
from torch.distributed import ProcessGroup
|
18
|
+
from zmq import IPV6 # type: ignore
|
19
|
+
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
20
|
+
|
21
|
+
# SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
|
22
|
+
SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
|
23
|
+
os.environ.get("SGLANG_RINGBUFFER_WARNING_INTERVAL", "60")
|
24
|
+
)
|
25
|
+
|
26
|
+
logger = logging.getLogger(__name__)
|
27
|
+
|
28
|
+
|
29
|
+
def get_ip() -> str:
    """Best-effort lookup of this host's IP address.

    Resolution order: the SGLANG_HOST_IP / HOST_IP environment variables,
    then the local address of a UDP socket "connected" to a public IPv4
    address, then the same trick over IPv6, and finally "0.0.0.0" with a
    warning if everything failed.
    """
    # Environment override wins over interface probing.
    configured = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
    if configured:
        return configured

    # IP is not set; probe the routing table by "connecting" a UDP socket
    # (no packet is sent) and reading back the chosen local address.
    # The IPv6 target is Google's public DNS server, see
    # https://developers.google.com/speed/public-dns/docs/using#addresses
    probe_targets = [
        (socket.AF_INET, ("8.8.8.8", 80)),
        (socket.AF_INET6, ("2001:4860:4860::8888", 80)),
    ]
    for family, target in probe_targets:
        try:
            probe = socket.socket(family, socket.SOCK_DGRAM)
            probe.connect(target)  # Doesn't need to be reachable
            return probe.getsockname()[0]
        except Exception:
            pass

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default."
        "The value can be set by the environment variable"
        " SGLANG_HOST_IP or HOST_IP.",
        stacklevel=2,
    )
    return "0.0.0.0"
62
|
+
|
63
|
+
|
64
|
+
def get_open_port() -> int:
    """Return a TCP port number this process can bind to.

    If the SGLANG_PORT environment variable is set, start from that port and
    probe upward until a free one is found. Otherwise ask the OS for an
    ephemeral port (IPv4 first, falling back to IPv6 if IPv4 binding fails).

    Note: the probe socket is closed before returning, so the port is only
    *likely* free when the caller binds it.
    """
    port = os.getenv("SGLANG_PORT")
    if port is not None:
        # BUGFIX: os.getenv returns a string; binding needs an int and the
        # retry below increments it arithmetically.
        port = int(port)
        while True:
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(("", port))
                    return port
            except OSError:
                port += 1  # Increment port number if already in use
                logger.info("Port %d is already in use, trying port %d", port - 1, port)
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
|
86
|
+
|
87
|
+
|
88
|
+
def is_valid_ipv6_address(address: str) -> bool:
    """Return True iff *address* parses as a literal IPv6 address."""
    try:
        ipaddress.IPv6Address(address)
    except ValueError:
        return False
    return True
|
94
|
+
|
95
|
+
|
96
|
+
class ShmRingBuffer:
    """Single-writer, multi-reader ring buffer backed by POSIX shared memory."""

    def __init__(
        self,
        n_reader: int,
        max_chunk_bytes: int,
        max_chunks: int,
        name: Optional[str] = None,
    ):
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially a queue where exactly one process enqueues and multiple
        processes dequeue. Because the maximum item size and the maximum
        number of stored items are fixed up front, no locking is needed to
        synchronize access.

        Buffer memory layout:
                      data                                 metadata
                        |                                      |
                        | (current_idx)                        | (current_idx)
                        v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        Per-chunk metadata is one byte per flag: the first byte is the
        "written" flag, the remaining n_reader bytes are per-reader
        "already read" flags; all start at 0.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

        Metadata states:
        (case 1) 0???...???: not written yet; cannot read, can write
        (case 2) 1000...000: just written; can read, cannot write
        (case 3) 1???...???: written and read by some readers; can read if
                 this reader has not read it yet, cannot write
        (case 4) 1111...111: written and read by all readers; cannot read,
                 can write

        Reader transition: a reader that finds a readable chunk (case 2/3)
        yields it to the caller, and only after the caller finishes reading
        does it set its own read flag (0 -> 1). Readers never touch the
        written flag.

        Writer transition: on a writable chunk (case 1/4) the writer first
        clears the written flag (forcing case 1), yields the chunk, then
        clears every reader flag and finally sets the written flag (0 -> 1).
        NOTE: the order matters — clearing reader flags first keeps the
        chunk in case 1 so the single-byte write of the written flag is the
        atomic commit; the reverse order would expose an intermediate case 3
        that readers could observe incorrectly.

        When `name` is None a fresh segment is created; pickling the object
        ships only the segment name (see __reduce__), and the receiving
        process reopens the same segment by name.
        """  # noqa
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (
            self.max_chunk_bytes + self.metadata_size
        ) * self.max_chunks
        # Data chunks live at the front of the segment, metadata at the back.
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        self.is_creator = name is None
        if self.is_creator:
            # Create a fresh segment and zero the metadata region so every
            # chunk starts in the "not written" state.
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer
            )
            with memoryview(
                self.shared_memory.buf[self.metadata_offset :]
            ) as metadata_buffer:
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # Attach to an existing segment by name.
            # Workaround for https://stackoverflow.com/q/62748654/9191338 :
            # Python's resource tracker would otherwise claim ownership of a
            # segment this process did not create, so registration is
            # suppressed while attaching.
            with patch(
                "multiprocessing.resource_tracker.register",
                lambda *args, **kwargs: None,
            ):
                try:
                    self.shared_memory = shared_memory.SharedMemory(name=name)
                    assert self.shared_memory.size == self.total_bytes_of_buffer
                except FileNotFoundError:
                    # The object may have been deserialized on a different
                    # node where the segment does not exist; it is unused
                    # there, so the error is deliberately suppressed.
                    pass

    def __reduce__(self):
        # Pickle as (class, ctor-args): receivers reattach by segment name.
        ctor_args = (
            self.n_reader,
            self.max_chunk_bytes,
            self.max_chunks,
            self.shared_memory.name,
        )
        return (self.__class__, ctor_args)

    def __del__(self):
        # Attachment may have failed (see __init__), hence the hasattr guard.
        if hasattr(self, "shared_memory"):
            self.shared_memory.close()
            if self.is_creator:
                self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        """Yield a writable memoryview over the data chunk at *current_idx*."""
        lo = self.data_offset + current_idx * self.max_chunk_bytes
        hi = lo + self.max_chunk_bytes
        with memoryview(self.shared_memory.buf[lo:hi]) as view:
            yield view

    @contextmanager
    def get_metadata(self, current_idx: int):
        """Yield a writable memoryview over the metadata of chunk *current_idx*."""
        lo = self.metadata_offset + current_idx * self.metadata_size
        hi = lo + self.metadata_size
        with memoryview(self.shared_memory.buf[lo:hi]) as view:
            yield view
|
224
|
+
|
225
|
+
|
226
|
+
@dataclass
class Handle:
    """Picklable connection info used to attach readers to a MessageQueue.

    Produced by the writer (see ``MessageQueue.export_handle``) and consumed
    by ``MessageQueue.create_from_handle`` in reader processes.
    """

    # Address remote readers connect to; local readers use 127.0.0.1.
    connect_ip: str
    # Ranks of the readers on the writer's node, which read through the
    # shared-memory ring buffer instead of a TCP socket.
    local_reader_ranks: List[int] = field(default_factory=list)

    # Shared-memory ring buffer; None when there are no local readers.
    buffer: Optional[ShmRingBuffer] = None
    # Port of the writer's loopback XPUB socket; None if no local readers.
    local_subscribe_port: Optional[int] = None
    # Port of the writer's remote-facing XPUB socket; None if no remote readers.
    remote_subscribe_port: Optional[int] = None
|
234
|
+
|
235
|
+
|
236
|
+
class MessageQueue:
    """One-writer, many-reader broadcast queue.

    Readers on the writer's node receive small messages through a
    shared-memory ring buffer (``ShmRingBuffer``) and large messages through
    a loopback ZeroMQ XPUB/SUB channel; readers on other nodes always
    receive through a remote-facing XPUB/SUB channel. The code that walks
    the ring-buffer flags (``acquire_write`` / ``acquire_read``) depends on
    the exact statement order documented inline — do not reorder it.
    """

    def __init__(
        self,
        n_reader,  # number of all readers
        n_local_reader,  # number of local readers through shared memory
        local_reader_ranks: Optional[List[int]] = None,
        max_chunk_bytes: int = 1024 * 1024 * 10,
        max_chunks: int = 10,
        connect_ip: Optional[str] = None,
    ):
        """Create the writer-side endpoint and its communication handle.

        Args:
            n_reader: total number of readers (local + remote).
            n_local_reader: readers on this node served via shared memory.
            local_reader_ranks: ranks of the local readers; defaults to
                ``range(n_local_reader)``.
            max_chunk_bytes: ring-buffer chunk size; larger serialized
                objects overflow to the ZeroMQ socket.
            max_chunks: number of chunks in the ring buffer.
            connect_ip: address handed to remote readers; autodetected via
                ``get_ip()`` when remote readers exist, else loopback.
        """
        if local_reader_ranks is None:
            local_reader_ranks = list(range(n_local_reader))
        else:
            assert len(local_reader_ranks) == n_local_reader
        self.n_local_reader = n_local_reader
        n_remote_reader = n_reader - n_local_reader
        self.n_remote_reader = n_remote_reader

        if connect_ip is None:
            connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"

        context = Context()

        if n_local_reader > 0:
            # for local readers, we will:
            # 1. create a shared memory ring buffer to communicate small data
            # 2. create a publish-subscribe socket to communicate large data
            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks)

            # XPUB is very similar to PUB,
            # except that it can receive subscription messages
            # to confirm the number of subscribers
            self.local_socket = context.socket(XPUB)
            # set the verbose option so that we can receive every subscription
            # message. otherwise, we will only receive the first subscription
            # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
            self.local_socket.setsockopt(XPUB_VERBOSE, True)
            local_subscribe_port = get_open_port()
            socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}"
            logger.debug("Binding to %s", socket_addr)
            self.local_socket.bind(socket_addr)

            self.current_idx = 0

        else:
            self.buffer = None  # type: ignore
            local_subscribe_port = None
            self.local_socket = None
            self.current_idx = -1

        if n_remote_reader > 0:
            # for remote readers, we will:
            # create a publish-subscribe socket to communicate large data
            self.remote_socket = context.socket(XPUB)
            self.remote_socket.setsockopt(XPUB_VERBOSE, True)
            remote_subscribe_port = get_open_port()
            if is_valid_ipv6_address(connect_ip):
                self.remote_socket.setsockopt(IPV6, 1)
            socket_addr = f"tcp://*:{remote_subscribe_port}"
            self.remote_socket.bind(socket_addr)

        else:
            remote_subscribe_port = None
            self.remote_socket = None

        self._is_writer = True
        self._is_local_reader = False
        self.local_reader_rank = -1
        # rank does not matter for remote readers
        self._is_remote_reader = False

        self.handle = Handle(
            connect_ip=connect_ip,
            local_reader_ranks=local_reader_ranks,
            buffer=self.buffer,
            local_subscribe_port=local_subscribe_port,
            remote_subscribe_port=remote_subscribe_port,
        )

        logger.info("vLLM message queue communication handle: %s", self.handle)

    def export_handle(self) -> Handle:
        """Return the picklable handle readers use to attach to this queue."""
        return self.handle

    @staticmethod
    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
        """Build a reader-side endpoint from a writer-exported *handle*.

        *rank* decides the role: ranks listed in ``handle.local_reader_ranks``
        become local (shared-memory) readers; everyone else becomes a remote
        (TCP-only) reader.
        """
        # __new__ is used deliberately: the reader must not run the writer's
        # __init__ (which creates the buffer and binds sockets).
        self = MessageQueue.__new__(MessageQueue)
        self.handle = handle
        self._is_writer = False

        context = Context()

        if rank in handle.local_reader_ranks:
            assert handle.buffer is not None
            self.buffer = handle.buffer
            self.current_idx = 0
            self.local_reader_rank = handle.local_reader_ranks.index(rank)
            self._is_local_reader = True
            self._is_remote_reader = False

            self.local_socket = context.socket(SUB)
            self.local_socket.setsockopt_string(SUBSCRIBE, "")
            socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}"
            logger.debug("Connecting to %s", socket_addr)
            self.local_socket.connect(socket_addr)

            self.remote_socket = None
        else:
            self.buffer = None  # type: ignore
            self.current_idx = -1
            self.local_reader_rank = -1
            self._is_local_reader = False
            self._is_remote_reader = True

            self.local_socket = None

            self.remote_socket = context.socket(SUB)
            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
            if is_valid_ipv6_address(handle.connect_ip):
                self.remote_socket.setsockopt(IPV6, 1)
            socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
            logger.debug("Connecting to %s", socket_addr)
            self.remote_socket.connect(socket_addr)

        return self

    def wait_until_ready(self):
        """This is a collective operation. All processes (including the
        readers and the writer) should call this function.

        The writer waits for one subscription message per reader on each
        channel, then publishes a READY message so every reader knows its
        subscription is live.
        """
        if self._is_writer:
            # wait for all readers to connect

            # local readers
            for i in range(self.n_local_reader):
                # wait for subscription messages from all local readers
                self.local_socket.recv()
            if self.n_local_reader > 0:
                # send a message to all local readers
                # to make sure the publish channel is working
                self.local_socket.send(b"READY")

            # remote readers
            for i in range(self.n_remote_reader):
                # wait for subscription messages from all remote readers
                self.remote_socket.recv()
            if self.n_remote_reader > 0:
                # send a message to all remote readers
                # to make sure the publish channel is working
                self.remote_socket.send(b"READY")
        elif self._is_local_reader:
            # wait for the writer to send a message
            recv = self.local_socket.recv()
            assert recv == b"READY"
        elif self._is_remote_reader:
            # wait for the writer to send a message
            recv = self.remote_socket.recv()
            assert recv == b"READY"

    @contextmanager
    def acquire_write(self):
        """Yield a writable memoryview over the next free ring-buffer chunk.

        Spins (with ``sched_yield``) until the chunk at ``current_idx`` has
        been read by all readers, then advances ``current_idx`` on exit.
        Writer-only.
        """
        assert self._is_writer, "Only writers can acquire write"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # for writers, `self.current_idx` is the next block to write
                    # if this block is not ready to write,
                    # we need to wait until it is read by all readers

                    # Release the processor to other threads
                    os.sched_yield()

                    # if we wait for a long time, we should warn the user
                    if (
                        time.monotonic() - start_time
                        > SGLANG_RINGBUFFER_WARNING_INTERVAL * n_warning
                    ):
                        logger.warning(
                            "No available block found in %s second. ",
                            SGLANG_RINGBUFFER_WARNING_INTERVAL,
                        )
                        n_warning += 1

                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
                break

    @contextmanager
    def acquire_read(self):
        """Yield a memoryview over the next unread ring-buffer chunk.

        Spins (with ``sched_yield``) until the chunk at ``current_idx`` is
        written and not yet read by this reader; marks it read and advances
        ``current_idx`` on exit. Local-reader-only.
        """
        assert self._is_local_reader, "Only readers can acquire read"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.local_reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader

                    # for readers, `self.current_idx` is the next block to read
                    # if this block is not ready,
                    # we need to wait until it is written

                    # Release the processor to other threads
                    os.sched_yield()

                    # if we wait for a long time, we should warn the user
                    if (
                        time.monotonic() - start_time
                        > SGLANG_RINGBUFFER_WARNING_INTERVAL * n_warning
                    ):
                        logger.warning(
                            "No available block found in %s second. ",
                            SGLANG_RINGBUFFER_WARNING_INTERVAL,
                        )
                        n_warning += 1

                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.local_reader_rank + 1] = 1
                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
                break

    def enqueue(self, obj):
        """Broadcast *obj* (pickled) to all readers. Writer-only.

        Small objects go into the ring buffer for local readers; oversized
        ones set an overflow marker and travel over the local socket instead.
        Remote readers always receive over the remote socket.
        """
        assert self._is_writer, "Only writers can enqueue"
        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if self.n_local_reader > 0:
            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
                with self.acquire_write() as buf:
                    buf[0] = 1  # overflow
                self.local_socket.send(serialized_obj)
            else:
                with self.acquire_write() as buf:
                    buf[0] = 0  # not overflow
                    buf[1 : len(serialized_obj) + 1] = serialized_obj
        if self.n_remote_reader > 0:
            self.remote_socket.send(serialized_obj)

    def dequeue(self):
        """Receive and unpickle the next broadcast object. Reader-only.

        Raises:
            RuntimeError: if called on the writer endpoint.
        """
        if self._is_local_reader:
            with self.acquire_read() as buf:
                overflow = buf[0] == 1
                if not overflow:
                    # no need to know the size of serialized object
                    # pickle format contains the size information internally
                    # see https://docs.python.org/3/library/pickle.html
                    obj = pickle.loads(buf[1:])
            if overflow:
                recv = self.local_socket.recv()
                obj = pickle.loads(recv)
        elif self._is_remote_reader:
            recv = self.remote_socket.recv()
            obj = pickle.loads(recv)
        else:
            raise RuntimeError("Only readers can dequeue")
        return obj

    def broadcast_object(self, obj=None):
        """Collective convenience wrapper: the writer enqueues *obj* and
        returns it; every reader returns the dequeued object."""
        if self._is_writer:
            self.enqueue(obj)
            return obj
        else:
            return self.dequeue()

    @staticmethod
    def create_from_process_group(
        pg: ProcessGroup, max_chunk_bytes, max_chunks, writer_rank=0
    ) -> "MessageQueue":
        """Collective constructor over a torch.distributed process group.

        The writer rank builds the queue and broadcasts its handle; all other
        ranks attach via the handle. Every rank calls wait_until_ready before
        returning, so the queue is usable immediately.
        """
        group_rank = dist.get_rank(pg)
        group_world_size = dist.get_world_size(pg)
        global_ranks = dist.get_process_group_ranks(pg)

        from sglang.srt.distributed.parallel_state import in_the_same_node_as

        # Ranks co-located with the writer become shared-memory readers.
        status = in_the_same_node_as(pg, source_rank=writer_rank)
        same_node_ranks = [i for i, s in enumerate(status) if s]
        n_reader = group_world_size - 1
        n_local_reader = len(same_node_ranks) - 1
        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
        buffer_io: MessageQueue
        if group_rank == writer_rank:
            buffer_io = MessageQueue(
                n_reader=n_reader,
                n_local_reader=n_local_reader,
                local_reader_ranks=local_reader_ranks,
                max_chunk_bytes=max_chunk_bytes,
                max_chunks=max_chunks,
            )
            handle = buffer_io.export_handle()
            dist.broadcast_object_list(
                [handle], src=global_ranks[writer_rank], group=pg
            )
        else:
            recv = [None]
            dist.broadcast_object_list(recv, src=global_ranks[writer_rank], group=pg)
            handle = recv[0]  # type: ignore
            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
        buffer_io.wait_until_ready()
        return buffer_io
|
@@ -0,0 +1,47 @@
|
|
1
|
+
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/xpu_communicator.py
|
2
|
+
import torch
|
3
|
+
import torch.distributed as dist
|
4
|
+
from torch.distributed import ProcessGroup
|
5
|
+
|
6
|
+
from sglang.srt.utils import is_xpu
|
7
|
+
|
8
|
+
|
9
|
+
class XpuCommunicator:
    """Collective-communication helper for Intel XPU devices.

    On non-XPU platforms the instance is constructed in a disabled state and
    performs no group setup.
    """

    def __init__(self, group: ProcessGroup):
        if not is_xpu():
            # Not an XPU platform: leave the communicator disabled.
            self.disabled = True
            return
        self.disabled = False
        self.group = group
        self.world_size = dist.get_world_size(self.group)

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce *x* in place across the group and return it."""
        dist.all_reduce(x, group=self.group)
        return x

    def gather(
        self, input_: torch.Tensor, rank_in_group: int, dst: int = 0, dim: int = -1
    ):
        """Gather *input_* from every rank onto *dst* along *dim*.

        Returns the concatenated tensor on the destination rank and None on
        all other ranks.
        """
        # For xpu path, gather doesn't work properly together with ray
        # cluster so we use all_gather instead for now.
        shape = input_.size()
        gathered = torch.empty(
            (self.world_size,) + shape, dtype=input_.dtype, device=input_.device
        )
        torch.distributed.all_gather_into_tensor(gathered, input_, group=self.group)
        if rank_in_group != dst:
            # Non-destination ranks discard the gathered copy.
            return None
        # Move the world dimension next to `dim`, then fold it into `dim`.
        merged = gathered.movedim(0, dim)
        return merged.reshape(
            shape[:dim] + (self.world_size * shape[dim],) + shape[dim + 1 :]
        )
|