sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/device_communicators/shm_broadcast.py
@@ -0,0 +1,568 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/shm_broadcast.py
+import ipaddress
+import logging
+import os
+import pickle
+import socket
+import time
+import warnings
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from multiprocessing import shared_memory
+from typing import List, Optional
+from unittest.mock import patch
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+from zmq import IPV6  # type: ignore
+from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
+
+# Warning interval (seconds) while blocked on the ring buffer; defaults to 60.
+SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
+    os.environ.get("SGLANG_RINGBUFFER_WARNING_INTERVAL", "60")
+)
+
+logger = logging.getLogger(__name__)
+
+
+def get_ip() -> str:
+    # Use SGLANG_HOST_IP if set; otherwise fall back to HOST_IP.
+    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+    if host_ip:
+        return host_ip
+
+    # IP is not set, try to get it from the network interface
+
+    # try ipv4
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    try:
+        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
+        pass
+
+    # try ipv6
+    try:
+        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+        # Google's public DNS server, see
+        # https://developers.google.com/speed/public-dns/docs/using#addresses
+        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
+        pass
+
+    warnings.warn(
+        "Failed to get the IP address, using 0.0.0.0 by default. "
+        "The value can be set by the environment variable"
+        " SGLANG_HOST_IP or HOST_IP.",
+        stacklevel=2,
+    )
+    return "0.0.0.0"
+
+
+def get_open_port() -> int:
+    port = os.getenv("SGLANG_PORT")
+    if port is not None:
+        port = int(port)  # environment variables are strings
+        while True:
+            try:
+                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                    s.bind(("", port))
+                    return port
+            except OSError:
+                port += 1  # Increment port number if already in use
+                logger.info("Port %d is already in use, trying port %d", port - 1, port)
+    # try ipv4
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+    except OSError:
+        # try ipv6
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+
+
+def is_valid_ipv6_address(address: str) -> bool:
+    try:
+        ipaddress.IPv6Address(address)
+        return True
+    except ValueError:
+        return False
+
+
+class ShmRingBuffer:
+
+    def __init__(
+        self,
+        n_reader: int,
+        max_chunk_bytes: int,
+        max_chunks: int,
+        name: Optional[str] = None,
+    ):
+        """
+        A shared memory ring buffer implementation for broadcast communication.
+        Essentially, it is a queue where only one will `enqueue` and multiple
+        will `dequeue`. The max size of each item, together with the max number
+        of items that can be stored in the buffer, are known in advance.
+        In this case, we don't need to synchronize the access to
+        the buffer.
+
+        Buffer memory layout:
+                  data                                 metadata
+                    |                                      |
+                    | (current_idx)                        | (current_idx)
+                    v                                      v
+        +-------------------------------+----------------------------------------+
+        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
+        +-------------------------------+----------------------------------------+
+        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |
+
+        metadata memory layout: each byte is a flag, the first byte is the written
+        flag, and the rest are reader flags. The flags are set to 0 by default.
+        +--------------+--------------+--------------+-----+--------------+
+        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
+        +--------------+--------------+--------------+-----+--------------+
+
+        The state of metadata is as follows:
+
+        (case 1) 0???...???: the block is not written yet, cannot read, can write
+        (case 2) 1000...000: the block is just written, can read, cannot write
+        (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
+        (case 4) 1111...111: the block is written and read by all readers, cannot read, can write
+
+        State transition for readers:
+
+        When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
+        Only after the caller finishes reading the block, the reader can mark the block as read.
+        Readers only mark the block as read (from 0 to 1); the writer marks the block as ready to read (from 0 to 1).
+
+        State transition for writer:
+
+        When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
+        to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
+        can reset the reader flags to 0, and mark the block as written (from 0 to 1).
+        NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.
+
+        During creation, `name` is None and the buffer is created. We can pass the
+        created object to other processes by pickling it. The other processes will
+        get the name of the shared memory and open it, so that they can access the
+        same shared memory buffer.
+        """  # noqa
+        self.n_reader = n_reader
+        self.metadata_size = 1 + n_reader
+        self.max_chunk_bytes = max_chunk_bytes
+        self.max_chunks = max_chunks
+        self.total_bytes_of_buffer = (
+            self.max_chunk_bytes + self.metadata_size
+        ) * self.max_chunks
+        self.data_offset = 0
+        self.metadata_offset = self.max_chunk_bytes * self.max_chunks
+
+        if name is None:
+            # we are creating a buffer
+            self.is_creator = True
+            self.shared_memory = shared_memory.SharedMemory(
+                create=True, size=self.total_bytes_of_buffer
+            )
+            # initialize the metadata section to 0
+            with memoryview(
+                self.shared_memory.buf[self.metadata_offset :]
+            ) as metadata_buffer:
+                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
+        else:
+            # we are opening an existing buffer
+            self.is_creator = False
+            # fix to https://stackoverflow.com/q/62748654/9191338
+            # Python incorrectly tracks shared memory even if it is not
+            # created by the process. The following patch is a workaround.
+            with patch(
+                "multiprocessing.resource_tracker.register",
+                lambda *args, **kwargs: None,
+            ):
+                try:
+                    self.shared_memory = shared_memory.SharedMemory(name=name)
+                    assert self.shared_memory.size == self.total_bytes_of_buffer
+                except FileNotFoundError:
+                    # we might deserialize the object in a different node
+                    # in this case, this object is not used,
+                    # and we should suppress the error
+                    pass
+
+    def __reduce__(self):
+        return (
+            self.__class__,
+            (
+                self.n_reader,
+                self.max_chunk_bytes,
+                self.max_chunks,
+                self.shared_memory.name,
+            ),
+        )
+
+    def __del__(self):
+        if hasattr(self, "shared_memory"):
+            self.shared_memory.close()
+            if self.is_creator:
+                self.shared_memory.unlink()
+
+    @contextmanager
+    def get_data(self, current_idx: int):
+        start = self.data_offset + current_idx * self.max_chunk_bytes
+        end = start + self.max_chunk_bytes
+        with memoryview(self.shared_memory.buf[start:end]) as buf:
+            yield buf
+
+    @contextmanager
+    def get_metadata(self, current_idx: int):
+        start = self.metadata_offset + current_idx * self.metadata_size
+        end = start + self.metadata_size
+        with memoryview(self.shared_memory.buf[start:end]) as buf:
+            yield buf
+
+
+@dataclass
+class Handle:
+    connect_ip: str
+    local_reader_ranks: List[int] = field(default_factory=list)
+
+    buffer: Optional[ShmRingBuffer] = None
+    local_subscribe_port: Optional[int] = None
+    remote_subscribe_port: Optional[int] = None
+
+
+class MessageQueue:
+
+    def __init__(
+        self,
+        n_reader,  # number of all readers
+        n_local_reader,  # number of local readers through shared memory
+        local_reader_ranks: Optional[List[int]] = None,
+        max_chunk_bytes: int = 1024 * 1024 * 10,
+        max_chunks: int = 10,
+        connect_ip: Optional[str] = None,
+    ):
+        if local_reader_ranks is None:
+            local_reader_ranks = list(range(n_local_reader))
+        else:
+            assert len(local_reader_ranks) == n_local_reader
+        self.n_local_reader = n_local_reader
+        n_remote_reader = n_reader - n_local_reader
+        self.n_remote_reader = n_remote_reader
+
+        if connect_ip is None:
+            connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
+
+        context = Context()
+
+        if n_local_reader > 0:
+            # for local readers, we will:
+            # 1. create a shared memory ring buffer to communicate small data
+            # 2. create a publish-subscribe socket to communicate large data
+            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks)
+
+            # XPUB is very similar to PUB,
+            # except that it can receive subscription messages
+            # to confirm the number of subscribers
+            self.local_socket = context.socket(XPUB)
+            # set the verbose option so that we can receive every subscription
+            # message. otherwise, we will only receive the first subscription
+            # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
+            self.local_socket.setsockopt(XPUB_VERBOSE, True)
+            local_subscribe_port = get_open_port()
+            socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}"
+            logger.debug("Binding to %s", socket_addr)
+            self.local_socket.bind(socket_addr)
+
+            self.current_idx = 0
+
+        else:
+            self.buffer = None  # type: ignore
+            local_subscribe_port = None
+            self.local_socket = None
+            self.current_idx = -1
+
+        if n_remote_reader > 0:
+            # for remote readers, we will:
+            # create a publish-subscribe socket to communicate large data
+            self.remote_socket = context.socket(XPUB)
+            self.remote_socket.setsockopt(XPUB_VERBOSE, True)
+            remote_subscribe_port = get_open_port()
+            if is_valid_ipv6_address(connect_ip):
+                self.remote_socket.setsockopt(IPV6, 1)
+            socket_addr = f"tcp://*:{remote_subscribe_port}"
+            self.remote_socket.bind(socket_addr)
+
+        else:
+            remote_subscribe_port = None
+            self.remote_socket = None
+
+        self._is_writer = True
+        self._is_local_reader = False
+        self.local_reader_rank = -1
+        # rank does not matter for remote readers
+        self._is_remote_reader = False
+
+        self.handle = Handle(
+            connect_ip=connect_ip,
+            local_reader_ranks=local_reader_ranks,
+            buffer=self.buffer,
+            local_subscribe_port=local_subscribe_port,
+            remote_subscribe_port=remote_subscribe_port,
+        )
+
+        logger.info("vLLM message queue communication handle: %s", self.handle)
+
+    def export_handle(self) -> Handle:
+        return self.handle
+
+    @staticmethod
+    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
+        self = MessageQueue.__new__(MessageQueue)
+        self.handle = handle
+        self._is_writer = False
+
+        context = Context()
+
+        if rank in handle.local_reader_ranks:
+            assert handle.buffer is not None
+            self.buffer = handle.buffer
+            self.current_idx = 0
+            self.local_reader_rank = handle.local_reader_ranks.index(rank)
+            self._is_local_reader = True
+            self._is_remote_reader = False
+
+            self.local_socket = context.socket(SUB)
+            self.local_socket.setsockopt_string(SUBSCRIBE, "")
+            socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}"
+            logger.debug("Connecting to %s", socket_addr)
+            self.local_socket.connect(socket_addr)
+
+            self.remote_socket = None
+        else:
+            self.buffer = None  # type: ignore
+            self.current_idx = -1
+            self.local_reader_rank = -1
+            self._is_local_reader = False
+            self._is_remote_reader = True
+
+            self.local_socket = None
+
+            self.remote_socket = context.socket(SUB)
+            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
+            if is_valid_ipv6_address(handle.connect_ip):
+                self.remote_socket.setsockopt(IPV6, 1)
+            socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
+            logger.debug("Connecting to %s", socket_addr)
+            self.remote_socket.connect(socket_addr)
+
+        return self
+
+    def wait_until_ready(self):
+        """This is a collective operation. All processes (including the
+        readers and the writer) should call this function.
+        """
+        if self._is_writer:
+            # wait for all readers to connect
+
+            # local readers
+            for i in range(self.n_local_reader):
+                # wait for subscription messages from all local readers
+                self.local_socket.recv()
+            if self.n_local_reader > 0:
+                # send a message to all local readers
+                # to make sure the publish channel is working
+                self.local_socket.send(b"READY")
+
+            # remote readers
+            for i in range(self.n_remote_reader):
+                # wait for subscription messages from all remote readers
+                self.remote_socket.recv()
+            if self.n_remote_reader > 0:
+                # send a message to all remote readers
+                # to make sure the publish channel is working
+                self.remote_socket.send(b"READY")
+        elif self._is_local_reader:
+            # wait for the writer to send a message
+            recv = self.local_socket.recv()
+            assert recv == b"READY"
+        elif self._is_remote_reader:
+            # wait for the writer to send a message
+            recv = self.remote_socket.recv()
+            assert recv == b"READY"
+
+    @contextmanager
+    def acquire_write(self):
+        assert self._is_writer, "Only writers can acquire write"
+        start_time = time.monotonic()
+        n_warning = 1
+        while True:
+            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
+                read_count = sum(metadata_buffer[1:])
+                written_flag = metadata_buffer[0]
+                if written_flag and read_count != self.buffer.n_reader:
+                    # this block is written and not read by all readers
+                    # for writers, `self.current_idx` is the next block to write
+                    # if this block is not ready to write,
+                    # we need to wait until it is read by all readers
+
+                    # Release the processor to other threads
+                    os.sched_yield()
+
+                    # if we wait for a long time, we should warn the user
+                    if (
+                        time.monotonic() - start_time
+                        > SGLANG_RINGBUFFER_WARNING_INTERVAL * n_warning
+                    ):
+                        logger.warning(
+                            "No available block found in %s seconds.",
+                            SGLANG_RINGBUFFER_WARNING_INTERVAL,
+                        )
+                        n_warning += 1
+
+                    continue
+                # found a block that is either
+                # (1) not written
+                # (2) read by all readers
+
+                # mark the block as not written
+                metadata_buffer[0] = 0
+                # let caller write to the buffer
+                with self.buffer.get_data(self.current_idx) as buf:
+                    yield buf
+
+                # caller has written to the buffer
+                # NOTE: order is important here
+                # first set the read flags to 0
+                # then set the written flag to 1
+                # otherwise, the readers may think they already read the block
+                for i in range(1, self.buffer.n_reader + 1):
+                    # set read flag to 0, meaning it is not read yet
+                    metadata_buffer[i] = 0
+                # mark the block as written
+                metadata_buffer[0] = 1
+                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
+                break
+
+    @contextmanager
+    def acquire_read(self):
+        assert self._is_local_reader, "Only readers can acquire read"
+        start_time = time.monotonic()
+        n_warning = 1
+        while True:
+            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
+                read_flag = metadata_buffer[self.local_reader_rank + 1]
+                written_flag = metadata_buffer[0]
+                if not written_flag or read_flag:
+                    # this block is either
+                    # (1) not written
+                    # (2) already read by this reader
+
+                    # for readers, `self.current_idx` is the next block to read
+                    # if this block is not ready,
+                    # we need to wait until it is written
+
+                    # Release the processor to other threads
+                    os.sched_yield()
+
+                    # if we wait for a long time, we should warn the user
+                    if (
+                        time.monotonic() - start_time
+                        > SGLANG_RINGBUFFER_WARNING_INTERVAL * n_warning
+                    ):
+                        logger.warning(
+                            "No available block found in %s seconds.",
+                            SGLANG_RINGBUFFER_WARNING_INTERVAL,
+                        )
+                        n_warning += 1
+
+                    continue
+                # found a block that is not read by this reader
+                # let caller read from the buffer
+                with self.buffer.get_data(self.current_idx) as buf:
+                    yield buf
+
+                # caller has read from the buffer
+                # set the read flag
+                metadata_buffer[self.local_reader_rank + 1] = 1
+                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
+                break
+
+    def enqueue(self, obj):
+        assert self._is_writer, "Only writers can enqueue"
+        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
+        if self.n_local_reader > 0:
+            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
+                with self.acquire_write() as buf:
+                    buf[0] = 1  # overflow
+                self.local_socket.send(serialized_obj)
+            else:
+                with self.acquire_write() as buf:
+                    buf[0] = 0  # not overflow
+                    buf[1 : len(serialized_obj) + 1] = serialized_obj
+        if self.n_remote_reader > 0:
+            self.remote_socket.send(serialized_obj)
+
+    def dequeue(self):
+        if self._is_local_reader:
+            with self.acquire_read() as buf:
+                overflow = buf[0] == 1
+                if not overflow:
+                    # no need to know the size of serialized object
+                    # pickle format contains the size information internally
+                    # see https://docs.python.org/3/library/pickle.html
+                    obj = pickle.loads(buf[1:])
+            if overflow:
+                recv = self.local_socket.recv()
+                obj = pickle.loads(recv)
+        elif self._is_remote_reader:
+            recv = self.remote_socket.recv()
+            obj = pickle.loads(recv)
+        else:
+            raise RuntimeError("Only readers can dequeue")
+        return obj
+
+    def broadcast_object(self, obj=None):
+        if self._is_writer:
+            self.enqueue(obj)
+            return obj
+        else:
+            return self.dequeue()
+
+    @staticmethod
+    def create_from_process_group(
+        pg: ProcessGroup, max_chunk_bytes, max_chunks, writer_rank=0
+    ) -> "MessageQueue":
+        group_rank = dist.get_rank(pg)
+        group_world_size = dist.get_world_size(pg)
+        global_ranks = dist.get_process_group_ranks(pg)
+
+        from sglang.srt.distributed.parallel_state import in_the_same_node_as
+
+        status = in_the_same_node_as(pg, source_rank=writer_rank)
+        same_node_ranks = [i for i, s in enumerate(status) if s]
+        n_reader = group_world_size - 1
+        n_local_reader = len(same_node_ranks) - 1
+        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
+        buffer_io: MessageQueue
+        if group_rank == writer_rank:
+            buffer_io = MessageQueue(
+                n_reader=n_reader,
+                n_local_reader=n_local_reader,
+                local_reader_ranks=local_reader_ranks,
+                max_chunk_bytes=max_chunk_bytes,
+                max_chunks=max_chunks,
+            )
+            handle = buffer_io.export_handle()
+            dist.broadcast_object_list(
+                [handle], src=global_ranks[writer_rank], group=pg
+            )
+        else:
+            recv = [None]
+            dist.broadcast_object_list(recv, src=global_ranks[writer_rank], group=pg)
+            handle = recv[0]  # type: ignore
+            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
+        buffer_io.wait_until_ready()
+        return buffer_io
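
The `MessageQueue` above is a one-writer, many-reader broadcast primitive: same-node readers take small payloads from the shared-memory ring buffer, while oversized payloads and remote readers fall back to the ZeroMQ XPUB/SUB sockets. The following is a minimal single-node sketch of that flow; it assumes the module is importable at its new path shown in the file list, and the payload and process count are illustrative, not taken from sglang itself.

import multiprocessing as mp

from sglang.srt.distributed.device_communicators.shm_broadcast import MessageQueue


def reader_main(handle, rank):
    # Re-attach to the writer's queue from a pickled Handle; ranks 0 and 1
    # are in handle.local_reader_ranks, so this takes the shared-memory path.
    q = MessageQueue.create_from_handle(handle, rank)
    q.wait_until_ready()  # collective: blocks until the writer sends b"READY"
    print(f"reader {rank} received:", q.dequeue())


if __name__ == "__main__":
    # Use "spawn" so the Handle is pickled and each reader opens the shared
    # memory by name, instead of inheriting the creator's buffer via fork.
    ctx = mp.get_context("spawn")
    writer = MessageQueue(n_reader=2, n_local_reader=2)
    readers = [
        ctx.Process(target=reader_main, args=(writer.export_handle(), r))
        for r in range(2)
    ]
    for p in readers:
        p.start()
    writer.wait_until_ready()
    # Small objects travel through the ring buffer; anything at or above
    # max_chunk_bytes would instead be sent over the local XPUB socket.
    writer.enqueue({"step": 1, "payload": "hello"})
    for p in readers:
        p.join()

In-process, `create_from_process_group` wraps exactly this handshake: the writer rank constructs the queue, broadcasts the handle over the torch process group, and every rank then joins the `wait_until_ready` collective.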
sglang/srt/distributed/device_communicators/xpu_communicator.py
@@ -0,0 +1,47 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/xpu_communicator.py
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from sglang.srt.utils import is_xpu
+
+
+class XpuCommunicator:
+
+    def __init__(self, group: ProcessGroup):
+        if not is_xpu():
+            self.disabled = True
+            return
+        self.disabled = False
+        self.group = group
+        self.world_size = dist.get_world_size(self.group)
+
+    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+        dist.all_reduce(x, group=self.group)
+        return x
+
+    def gather(
+        self, input_: torch.Tensor, rank_in_group: int, dst: int = 0, dim: int = -1
+    ):
+        # For xpu path, gather doesn't work properly together with ray
+        # cluster so we use all_gather instead for now.
+        input_size = input_.size()
+        # Allocate output tensor.
+        output_tensor = torch.empty(
+            (self.world_size,) + input_size, dtype=input_.dtype, device=input_.device
+        )
+        # All-gather.
+        torch.distributed.all_gather_into_tensor(
+            output_tensor, input_, group=self.group
+        )
+        if rank_in_group == dst:
+            # Reshape
+            output_tensor = output_tensor.movedim(0, dim)
+            output_tensor = output_tensor.reshape(
+                input_size[:dim]
+                + (self.world_size * input_size[dim],)
+                + input_size[dim + 1 :]
+            )
+        else:
+            output_tensor = None
+        return output_tensor
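
The `movedim`/`reshape` pair in `gather` is pure shape bookkeeping: `all_gather_into_tensor` fills a `(world_size, *input_size)` tensor, and folding that leading axis into position `dim` is equivalent to concatenating the per-rank shards along `dim`. A single-process sketch with illustrative shapes (it assumes `dim` has already been normalized to a non-negative index, which the tuple-slicing above requires of its callers):

import torch

world_size, dim = 4, 1
shards = [torch.randn(2, 3) for _ in range(world_size)]  # one shard per rank
input_size = shards[0].size()

# What all_gather_into_tensor produces: shape (world_size, *input_size).
stacked = torch.stack(shards, dim=0)

# Move the gather axis into place, then flatten it into `dim`.
out = stacked.movedim(0, dim).reshape(
    input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
)

assert out.shape == (2, 12)
assert torch.equal(out, torch.cat(shards, dim=dim))  # same as a true gather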