sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/conversation.py +6 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +196 -51
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +18 -13
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +128 -43
- sglang/srt/disaggregation/utils.py +127 -123
- sglang/srt/entrypoints/engine.py +15 -1
- sglang/srt/entrypoints/http_server.py +13 -2
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +15 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
- sglang/srt/layers/attention/flashattention_backend.py +53 -64
- sglang/srt/layers/attention/flashinfer_backend.py +1 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/triton_backend.py +119 -119
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +23 -5
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
- sglang/srt/layers/moe/ep_moe/layer.py +42 -32
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
- sglang/srt/layers/moe/topk.py +16 -8
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/lora/lora_manager.py +79 -34
- sglang/srt/lora/mem_pool.py +4 -5
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/io_struct.py +28 -4
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +39 -6
- sglang/srt/managers/scheduler.py +73 -17
- sglang/srt/managers/tokenizer_manager.py +29 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/model_executor/cuda_graph_runner.py +122 -55
- sglang/srt/model_executor/forward_batch_info.py +14 -5
- sglang/srt/model_executor/model_runner.py +6 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_v2.py +113 -155
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +162 -4
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +318 -233
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
- sglang/srt/speculative/eagle_utils.py +389 -109
- sglang/srt/speculative/eagle_worker.py +134 -43
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +58 -0
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
@@ -31,10 +31,10 @@ import numpy as np
|
|
31
31
|
import torch
|
32
32
|
from torch.distributed import ProcessGroup
|
33
33
|
|
34
|
-
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver,
|
34
|
+
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
|
35
35
|
from sglang.srt.disaggregation.utils import (
|
36
|
+
FAKE_BOOTSTRAP_HOST,
|
36
37
|
DisaggregationMode,
|
37
|
-
FakeBootstrapHost,
|
38
38
|
KVClassType,
|
39
39
|
MetadataBuffers,
|
40
40
|
ReqToMetadataIdxAllocator,
|
@@ -47,7 +47,11 @@ from sglang.srt.disaggregation.utils import (
|
|
47
47
|
)
|
48
48
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
49
49
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
50
|
-
from sglang.srt.mem_cache.memory_pool import
|
50
|
+
from sglang.srt.mem_cache.memory_pool import (
|
51
|
+
KVCache,
|
52
|
+
ReqToTokenPool,
|
53
|
+
TokenToKVPoolAllocator,
|
54
|
+
)
|
51
55
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
52
56
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
53
57
|
|
@@ -145,7 +149,12 @@ class DecodePreallocQueue:
|
|
145
149
|
gloo_group: ProcessGroup,
|
146
150
|
tp_rank: int,
|
147
151
|
tp_size: int,
|
152
|
+
dp_size: int,
|
153
|
+
gpu_id: int,
|
148
154
|
bootstrap_port: int,
|
155
|
+
max_total_num_tokens: int,
|
156
|
+
prefill_pp_size: int,
|
157
|
+
num_reserved_decode_tokens: int,
|
149
158
|
transfer_backend: TransferBackend,
|
150
159
|
):
|
151
160
|
self.req_to_token_pool = req_to_token_pool
|
@@ -161,25 +170,33 @@ class DecodePreallocQueue:
|
|
161
170
|
self.gloo_group = gloo_group
|
162
171
|
self.tp_rank = tp_rank
|
163
172
|
self.tp_size = tp_size
|
173
|
+
self.dp_size = dp_size
|
174
|
+
self.gpu_id = gpu_id
|
164
175
|
self.bootstrap_port = bootstrap_port
|
165
|
-
|
166
|
-
self.
|
167
|
-
|
168
|
-
|
169
|
-
|
176
|
+
self.max_total_num_tokens = max_total_num_tokens
|
177
|
+
self.prefill_pp_size = prefill_pp_size
|
178
|
+
self.num_reserved_decode_tokens = num_reserved_decode_tokens
|
179
|
+
self.transfer_backend = transfer_backend
|
170
180
|
# Queue for requests pending pre-allocation
|
171
181
|
self.queue: List[DecodeRequest] = []
|
172
|
-
self.
|
182
|
+
self.retracted_queue: List[Req] = []
|
183
|
+
self.prefill_pp_size = prefill_pp_size
|
173
184
|
self.kv_manager = self._init_kv_manager()
|
174
185
|
|
175
186
|
def _init_kv_manager(self) -> BaseKVManager:
|
176
|
-
|
177
|
-
kv_args
|
187
|
+
kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
|
188
|
+
kv_args = kv_args_class()
|
189
|
+
|
190
|
+
attn_tp_size = self.tp_size // self.dp_size
|
191
|
+
kv_args.engine_rank = self.tp_rank % (attn_tp_size)
|
192
|
+
kv_args.decode_tp_size = attn_tp_size
|
193
|
+
kv_args.prefill_pp_size = self.prefill_pp_size
|
178
194
|
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
179
195
|
self.token_to_kv_pool.get_contiguous_buf_infos()
|
180
196
|
)
|
181
|
-
|
182
197
|
if self.draft_token_to_kv_pool is not None:
|
198
|
+
# We should also transfer draft model kv cache. The indices are
|
199
|
+
# always shared with a target model.
|
183
200
|
draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
|
184
201
|
self.draft_token_to_kv_pool.get_contiguous_buf_infos()
|
185
202
|
)
|
@@ -194,6 +211,7 @@ class DecodePreallocQueue:
|
|
194
211
|
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
|
195
212
|
self.metadata_buffers.get_buf_infos()
|
196
213
|
)
|
214
|
+
|
197
215
|
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
198
216
|
kv_args.gpu_id = self.scheduler.gpu_id
|
199
217
|
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
@@ -205,27 +223,84 @@ class DecodePreallocQueue:
|
|
205
223
|
)
|
206
224
|
return kv_manager
|
207
225
|
|
208
|
-
def add(self, req: Req) -> None:
|
226
|
+
def add(self, req: Req, is_retracted: bool = False) -> None:
|
209
227
|
"""Add a request to the pending queue."""
|
210
|
-
if req
|
211
|
-
|
212
|
-
|
228
|
+
if self._check_if_req_exceed_kv_capacity(req):
|
229
|
+
return
|
230
|
+
|
231
|
+
if is_retracted:
|
232
|
+
self.retracted_queue.append(req)
|
213
233
|
else:
|
214
|
-
|
215
|
-
|
234
|
+
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
|
235
|
+
kv_receiver_class = get_kv_class(
|
236
|
+
TransferBackend.FAKE, KVClassType.RECEIVER
|
237
|
+
)
|
238
|
+
else:
|
239
|
+
kv_receiver_class = get_kv_class(
|
240
|
+
self.transfer_backend, KVClassType.RECEIVER
|
241
|
+
)
|
242
|
+
|
243
|
+
kv_receiver = kv_receiver_class(
|
244
|
+
mgr=self.kv_manager,
|
245
|
+
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
246
|
+
bootstrap_room=req.bootstrap_room,
|
247
|
+
data_parallel_rank=req.data_parallel_rank,
|
248
|
+
)
|
249
|
+
|
250
|
+
self.queue.append(
|
251
|
+
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
|
216
252
|
)
|
217
|
-
kv_receiver = kv_receiver_class(
|
218
|
-
mgr=self.kv_manager,
|
219
|
-
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
220
|
-
bootstrap_room=req.bootstrap_room,
|
221
|
-
data_parallel_rank=req.data_parallel_rank,
|
222
|
-
)
|
223
|
-
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
|
224
253
|
|
225
|
-
def
|
254
|
+
def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
|
255
|
+
if len(req.origin_input_ids) > self.max_total_num_tokens:
|
256
|
+
message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
|
257
|
+
logger.error(message)
|
258
|
+
prepare_abort(req, message)
|
259
|
+
self.scheduler.stream_output([req], req.return_logprob)
|
260
|
+
return True
|
261
|
+
return False
|
262
|
+
|
263
|
+
def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
|
226
264
|
"""Add a request to the pending queue."""
|
227
265
|
for req in reqs:
|
228
|
-
self.add(req)
|
266
|
+
self.add(req, is_retracted=is_retracted)
|
267
|
+
|
268
|
+
def resume_retracted_reqs(self) -> List[Req]:
|
269
|
+
# TODO refactor the scheduling part, reuse with the unified engine logic as much as possible
|
270
|
+
|
271
|
+
# allocate memory
|
272
|
+
resumed_reqs = []
|
273
|
+
indices_to_remove = set()
|
274
|
+
allocatable_tokens = self._allocatable_tokens(count_retracted=False)
|
275
|
+
|
276
|
+
for i, req in enumerate(self.retracted_queue):
|
277
|
+
if self.req_to_token_pool.available_size() <= 0:
|
278
|
+
break
|
279
|
+
|
280
|
+
required_tokens_for_request = (
|
281
|
+
len(req.origin_input_ids)
|
282
|
+
+ len(req.output_ids)
|
283
|
+
+ self.num_reserved_decode_tokens
|
284
|
+
)
|
285
|
+
if required_tokens_for_request > allocatable_tokens:
|
286
|
+
break
|
287
|
+
|
288
|
+
resumed_reqs.append(req)
|
289
|
+
indices_to_remove.add(i)
|
290
|
+
req.is_retracted = False
|
291
|
+
self._pre_alloc(req)
|
292
|
+
allocatable_tokens -= required_tokens_for_request
|
293
|
+
|
294
|
+
# load from cpu, release the cpu copy
|
295
|
+
req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)
|
296
|
+
|
297
|
+
self.retracted_queue = [
|
298
|
+
entry
|
299
|
+
for i, entry in enumerate(self.retracted_queue)
|
300
|
+
if i not in indices_to_remove
|
301
|
+
]
|
302
|
+
|
303
|
+
return resumed_reqs
|
229
304
|
|
230
305
|
def _update_handshake_waiters(self) -> None:
|
231
306
|
if not self.queue:
|
@@ -255,6 +330,8 @@ class DecodePreallocQueue:
|
|
255
330
|
error_message,
|
256
331
|
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
257
332
|
)
|
333
|
+
else:
|
334
|
+
raise ValueError(f"Unexpected poll case: {poll}")
|
258
335
|
|
259
336
|
def pop_preallocated(self) -> List[DecodeRequest]:
|
260
337
|
"""Pop the preallocated requests from the pending queue (FIFO)."""
|
@@ -262,8 +339,16 @@ class DecodePreallocQueue:
|
|
262
339
|
|
263
340
|
preallocated_reqs = []
|
264
341
|
indices_to_remove = set()
|
265
|
-
allocatable_tokens = self._allocatable_tokens()
|
266
342
|
|
343
|
+
# We need to make sure that the sum of inflight tokens and allocatable tokens is greater than the maximum input+output length of each inflight request
|
344
|
+
# Otherwise, one request could run out of memory during decode while all other requests are in the transfer queue, where they cannot be retracted.
|
345
|
+
retractable_tokens = sum(
|
346
|
+
len(r.origin_input_ids) + len(r.output_ids)
|
347
|
+
for r in self.scheduler.running_batch.reqs
|
348
|
+
)
|
349
|
+
allocatable_tokens = self._allocatable_tokens(
|
350
|
+
retractable_tokens=retractable_tokens, count_retracted=True
|
351
|
+
)
|
267
352
|
# First, remove all failed requests from the queue
|
268
353
|
for i, decode_req in enumerate(self.queue):
|
269
354
|
if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
|
@@ -272,6 +357,7 @@ class DecodePreallocQueue:
|
|
272
357
|
)
|
273
358
|
indices_to_remove.add(i)
|
274
359
|
|
360
|
+
# Then, preallocate the remaining requests if possible
|
275
361
|
for i, decode_req in enumerate(self.queue):
|
276
362
|
if i in indices_to_remove:
|
277
363
|
continue
|
@@ -285,10 +371,23 @@ class DecodePreallocQueue:
|
|
285
371
|
if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
|
286
372
|
break
|
287
373
|
|
374
|
+
# Memory estimation: don't add if the projected memory cannot be met
|
375
|
+
# TODO: add new_token ratio
|
376
|
+
origin_input_len = len(decode_req.req.origin_input_ids)
|
288
377
|
required_tokens_for_request = (
|
289
|
-
|
378
|
+
origin_input_len + self.num_reserved_decode_tokens
|
290
379
|
)
|
291
380
|
|
381
|
+
if (
|
382
|
+
max(
|
383
|
+
required_tokens_for_request,
|
384
|
+
origin_input_len
|
385
|
+
+ decode_req.req.sampling_params.max_new_tokens
|
386
|
+
- retractable_tokens,
|
387
|
+
)
|
388
|
+
> allocatable_tokens
|
389
|
+
):
|
390
|
+
break
|
292
391
|
if required_tokens_for_request > allocatable_tokens:
|
293
392
|
break
|
294
393
|
|
@@ -301,7 +400,6 @@ class DecodePreallocQueue:
|
|
301
400
|
]
|
302
401
|
.cpu()
|
303
402
|
.numpy()
|
304
|
-
.astype(np.int64)
|
305
403
|
)
|
306
404
|
|
307
405
|
decode_req.metadata_buffer_index = (
|
@@ -321,15 +419,35 @@ class DecodePreallocQueue:
|
|
321
419
|
|
322
420
|
return preallocated_reqs
|
323
421
|
|
324
|
-
def _allocatable_tokens(
|
325
|
-
|
326
|
-
|
327
|
-
|
422
|
+
def _allocatable_tokens(
|
423
|
+
self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
|
424
|
+
) -> int:
|
425
|
+
need_space_for_single_req = (
|
426
|
+
max(
|
427
|
+
[
|
428
|
+
x.sampling_params.max_new_tokens
|
429
|
+
+ len(x.origin_input_ids)
|
430
|
+
- retractable_tokens
|
431
|
+
for x in self.scheduler.running_batch.reqs
|
432
|
+
]
|
433
|
+
)
|
434
|
+
if retractable_tokens is not None
|
435
|
+
and len(self.scheduler.running_batch.reqs) > 0
|
436
|
+
else 0
|
437
|
+
)
|
438
|
+
|
439
|
+
available_size = self.token_to_kv_pool_allocator.available_size()
|
440
|
+
|
441
|
+
allocatable_tokens = available_size - max(
|
442
|
+
# preserve some space for future decode
|
443
|
+
self.num_reserved_decode_tokens
|
328
444
|
* (
|
329
445
|
len(self.scheduler.running_batch.reqs)
|
330
446
|
+ len(self.transfer_queue.queue)
|
331
447
|
+ len(self.scheduler.waiting_queue)
|
332
|
-
)
|
448
|
+
),
|
449
|
+
# make sure each request can finish if it reaches max_tokens, with all other requests retracted
|
450
|
+
need_space_for_single_req,
|
333
451
|
)
|
334
452
|
|
335
453
|
# Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
|
@@ -342,15 +460,27 @@ class DecodePreallocQueue:
|
|
342
460
|
self.scheduler.last_batch.reqs
|
343
461
|
)
|
344
462
|
|
463
|
+
if count_retracted:
|
464
|
+
allocatable_tokens -= sum(
|
465
|
+
[
|
466
|
+
len(req.origin_input_ids)
|
467
|
+
+ len(req.output_ids)
|
468
|
+
+ self.num_reserved_decode_tokens
|
469
|
+
for req in self.retracted_queue
|
470
|
+
]
|
471
|
+
)
|
345
472
|
return allocatable_tokens
|
346
473
|
|
347
474
|
def _pre_alloc(self, req: Req) -> torch.Tensor:
|
348
475
|
"""Pre-allocate the memory for req_to_token and token_kv_pool"""
|
349
476
|
req_pool_indices = self.req_to_token_pool.alloc(1)
|
350
477
|
|
351
|
-
assert
|
478
|
+
assert (
|
479
|
+
req_pool_indices is not None
|
480
|
+
), "req_pool_indices is full! There is a bug in memory estimation."
|
352
481
|
|
353
482
|
req.req_pool_idx = req_pool_indices[0]
|
483
|
+
|
354
484
|
if self.token_to_kv_pool_allocator.page_size == 1:
|
355
485
|
kv_loc = self.token_to_kv_pool_allocator.alloc(
|
356
486
|
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
@@ -375,7 +505,10 @@ class DecodePreallocQueue:
|
|
375
505
|
),
|
376
506
|
extend_num_tokens=num_tokens,
|
377
507
|
)
|
378
|
-
|
508
|
+
|
509
|
+
assert (
|
510
|
+
kv_loc is not None
|
511
|
+
), "KV cache is full! There is a bug in memory estimation."
|
379
512
|
|
380
513
|
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
|
381
514
|
|
@@ -395,6 +528,7 @@ class DecodeTransferQueue:
|
|
395
528
|
self,
|
396
529
|
gloo_group: ProcessGroup,
|
397
530
|
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
531
|
+
tp_rank: int,
|
398
532
|
metadata_buffers: MetadataBuffers,
|
399
533
|
scheduler: Scheduler,
|
400
534
|
tree_cache: BasePrefixCache,
|
@@ -402,6 +536,7 @@ class DecodeTransferQueue:
|
|
402
536
|
self.queue: List[DecodeRequest] = []
|
403
537
|
self.gloo_group = gloo_group
|
404
538
|
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
539
|
+
self.tp_rank = tp_rank
|
405
540
|
self.metadata_buffers = metadata_buffers
|
406
541
|
self.scheduler = scheduler
|
407
542
|
self.tree_cache = tree_cache
|
@@ -412,10 +547,9 @@ class DecodeTransferQueue:
|
|
412
547
|
def extend(self, decode_reqs: List[DecodeRequest]) -> None:
|
413
548
|
self.queue.extend(decode_reqs)
|
414
549
|
|
415
|
-
def pop_transferred(self) -> List[
|
550
|
+
def pop_transferred(self) -> List[Req]:
|
416
551
|
if not self.queue:
|
417
552
|
return []
|
418
|
-
|
419
553
|
polls = poll_and_all_reduce(
|
420
554
|
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
|
421
555
|
)
|
@@ -424,7 +558,7 @@ class DecodeTransferQueue:
|
|
424
558
|
indices_to_remove = set()
|
425
559
|
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
|
426
560
|
if poll == KVPoll.Failed:
|
427
|
-
error_message = f"Decode transfer failed for request rank={self.
|
561
|
+
error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
|
428
562
|
try:
|
429
563
|
decode_req.kv_receiver.failure_exception()
|
430
564
|
except Exception as e:
|
@@ -499,15 +633,6 @@ class DecodeTransferQueue:
|
|
499
633
|
|
500
634
|
class SchedulerDisaggregationDecodeMixin:
|
501
635
|
|
502
|
-
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
|
503
|
-
batch, _ = self.prepare_dp_attn_batch(batch)
|
504
|
-
result = None
|
505
|
-
if batch:
|
506
|
-
result = self.run_batch(batch)
|
507
|
-
if not delay_process:
|
508
|
-
self.process_batch_result(batch, result)
|
509
|
-
return batch, result
|
510
|
-
|
511
636
|
@torch.no_grad()
|
512
637
|
def event_loop_normal_disagg_decode(self: Scheduler):
|
513
638
|
"""A normal scheduler loop for decode worker in disaggregation mode."""
|
@@ -543,13 +668,15 @@ class SchedulerDisaggregationDecodeMixin:
|
|
543
668
|
batch, _ = self._prepare_idle_batch_and_run(None)
|
544
669
|
|
545
670
|
if batch is None and (
|
546
|
-
len(self.
|
671
|
+
len(self.waiting_queue)
|
672
|
+
+ len(self.disagg_decode_transfer_queue.queue)
|
547
673
|
+ len(self.disagg_decode_prealloc_queue.queue)
|
548
674
|
== 0
|
549
675
|
):
|
550
676
|
# When the server is idle, do self-check and re-init some states
|
551
677
|
self.check_memory()
|
552
678
|
self.new_token_ratio = self.init_new_token_ratio
|
679
|
+
self.maybe_sleep_on_idle()
|
553
680
|
|
554
681
|
self.last_batch = batch
|
555
682
|
|
@@ -621,17 +748,28 @@ class SchedulerDisaggregationDecodeMixin:
|
|
621
748
|
self.process_batch_result(tmp_batch, tmp_result)
|
622
749
|
|
623
750
|
if batch is None and (
|
624
|
-
len(self.
|
751
|
+
len(self.waiting_queue)
|
752
|
+
+ len(self.disagg_decode_transfer_queue.queue)
|
625
753
|
+ len(self.disagg_decode_prealloc_queue.queue)
|
626
754
|
== 0
|
627
755
|
):
|
628
756
|
# When the server is idle, do self-check and re-init some states
|
629
757
|
self.check_memory()
|
630
758
|
self.new_token_ratio = self.init_new_token_ratio
|
759
|
+
self.maybe_sleep_on_idle()
|
631
760
|
|
632
761
|
self.last_batch = batch
|
633
762
|
self.last_batch_in_queue = last_batch_in_queue
|
634
763
|
|
764
|
+
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
|
765
|
+
batch, _ = self.prepare_dp_attn_batch(batch)
|
766
|
+
result = None
|
767
|
+
if batch:
|
768
|
+
result = self.run_batch(batch)
|
769
|
+
if not delay_process:
|
770
|
+
self.process_batch_result(batch, result)
|
771
|
+
return batch, result
|
772
|
+
|
635
773
|
def get_next_disagg_decode_batch_to_run(
|
636
774
|
self: Scheduler,
|
637
775
|
) -> Optional[Tuple[ScheduleBatch, bool]]:
|
@@ -714,6 +852,13 @@ class SchedulerDisaggregationDecodeMixin:
|
|
714
852
|
return new_batch
|
715
853
|
|
716
854
|
def process_decode_queue(self: Scheduler):
|
855
|
+
# try to resume retracted requests if there is enough space for another `num_reserved_decode_tokens` decode steps
|
856
|
+
resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
|
857
|
+
self.waiting_queue.extend(resumed_reqs)
|
858
|
+
if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
|
859
|
+
# if there are still retracted requests, we do not allocate new requests
|
860
|
+
return
|
861
|
+
|
717
862
|
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
|
718
863
|
self.disagg_decode_transfer_queue.extend(req_conns)
|
719
864
|
alloc_reqs = (
|
@@ -1 +1 @@
|
|
1
|
-
from .conn import FakeKVReceiver, FakeKVSender
|
1
|
+
from sglang.srt.disaggregation.fake.conn import FakeKVReceiver, FakeKVSender
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import logging
|
2
|
-
from typing import
|
2
|
+
from typing import List, Optional
|
3
3
|
|
4
4
|
import numpy as np
|
5
5
|
import numpy.typing as npt
|
@@ -8,7 +8,6 @@ from sglang.srt.disaggregation.base.conn import (
|
|
8
8
|
BaseKVManager,
|
9
9
|
BaseKVReceiver,
|
10
10
|
BaseKVSender,
|
11
|
-
KVArgs,
|
12
11
|
KVPoll,
|
13
12
|
)
|
14
13
|
|
@@ -17,7 +16,14 @@ logger = logging.getLogger(__name__)
|
|
17
16
|
|
18
17
|
# For warmup reqs, we don't kv transfer, we use the fake sender and receiver
|
19
18
|
class FakeKVSender(BaseKVSender):
|
20
|
-
def __init__(
|
19
|
+
def __init__(
|
20
|
+
self,
|
21
|
+
mgr: BaseKVManager,
|
22
|
+
bootstrap_addr: str,
|
23
|
+
bootstrap_room: int,
|
24
|
+
dest_tp_ranks: List[int],
|
25
|
+
pp_rank: int,
|
26
|
+
):
|
21
27
|
self.has_sent = False
|
22
28
|
|
23
29
|
def poll(self) -> KVPoll:
|
@@ -26,7 +32,7 @@ class FakeKVSender(BaseKVSender):
|
|
26
32
|
return KVPoll.WaitingForInput
|
27
33
|
else:
|
28
34
|
# Assume transfer completed instantly
|
29
|
-
logger.
|
35
|
+
logger.debug("FakeKVSender poll success")
|
30
36
|
return KVPoll.Success
|
31
37
|
|
32
38
|
def init(
|
@@ -34,17 +40,17 @@ class FakeKVSender(BaseKVSender):
|
|
34
40
|
kv_indices: list[int],
|
35
41
|
aux_index: Optional[int] = None,
|
36
42
|
):
|
37
|
-
logger.
|
43
|
+
logger.debug(
|
38
44
|
f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
|
39
45
|
)
|
40
46
|
pass
|
41
47
|
|
42
48
|
def send(
|
43
49
|
self,
|
44
|
-
kv_indices: npt.NDArray[np.
|
50
|
+
kv_indices: npt.NDArray[np.int32],
|
45
51
|
):
|
46
52
|
self.has_sent = True
|
47
|
-
logger.
|
53
|
+
logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
|
48
54
|
|
49
55
|
def failure_exception(self):
|
50
56
|
raise Exception("Fake KVSender Exception")
|
@@ -66,12 +72,12 @@ class FakeKVReceiver(BaseKVReceiver):
|
|
66
72
|
return KVPoll.WaitingForInput
|
67
73
|
else:
|
68
74
|
# Assume transfer completed instantly
|
69
|
-
logger.
|
75
|
+
logger.debug("FakeKVReceiver poll success")
|
70
76
|
return KVPoll.Success
|
71
77
|
|
72
78
|
def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
|
73
79
|
self.has_init = True
|
74
|
-
logger.
|
80
|
+
logger.debug(
|
75
81
|
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
|
76
82
|
)
|
77
83
|
|
@@ -28,12 +28,12 @@ from sglang.srt.disaggregation.base.conn import (
|
|
28
28
|
KVArgs,
|
29
29
|
KVPoll,
|
30
30
|
)
|
31
|
-
from sglang.srt.disaggregation.
|
32
|
-
from sglang.srt.disaggregation.utils import (
|
33
|
-
DisaggregationMode,
|
31
|
+
from sglang.srt.disaggregation.common.utils import (
|
34
32
|
FastQueue,
|
35
33
|
group_concurrent_contiguous,
|
36
34
|
)
|
35
|
+
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
36
|
+
from sglang.srt.disaggregation.utils import DisaggregationMode
|
37
37
|
from sglang.srt.server_args import ServerArgs
|
38
38
|
from sglang.srt.utils import (
|
39
39
|
get_free_port,
|
@@ -59,7 +59,7 @@ class KVTransferError(Exception):
|
|
59
59
|
@dataclasses.dataclass
|
60
60
|
class TransferKVChunk:
|
61
61
|
room: int
|
62
|
-
prefill_kv_indices: npt.NDArray[np.
|
62
|
+
prefill_kv_indices: npt.NDArray[np.int32]
|
63
63
|
index_slice: slice
|
64
64
|
is_last: bool
|
65
65
|
prefill_aux_index: Optional[int]
|
@@ -72,7 +72,7 @@ class TransferInfo:
|
|
72
72
|
endpoint: str
|
73
73
|
dst_port: int
|
74
74
|
mooncake_session_id: str
|
75
|
-
dst_kv_indices: npt.NDArray[np.
|
75
|
+
dst_kv_indices: npt.NDArray[np.int32]
|
76
76
|
dst_aux_index: int
|
77
77
|
required_dst_info_num: int
|
78
78
|
is_dummy: bool
|
@@ -81,10 +81,10 @@ class TransferInfo:
|
|
81
81
|
def from_zmq(cls, msg: List[bytes]):
|
82
82
|
if msg[4] == b"" and msg[5] == b"":
|
83
83
|
is_dummy = True
|
84
|
-
dst_kv_indices = np.array([], dtype=np.
|
84
|
+
dst_kv_indices = np.array([], dtype=np.int32)
|
85
85
|
dst_aux_index = None
|
86
86
|
else:
|
87
|
-
dst_kv_indices = np.frombuffer(msg[4], dtype=np.
|
87
|
+
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
|
88
88
|
dst_aux_index = int(msg[5].decode("ascii"))
|
89
89
|
is_dummy = False
|
90
90
|
return cls(
|
@@ -233,9 +233,9 @@ class MooncakeKVManager(BaseKVManager):
|
|
233
233
|
def send_kvcache(
|
234
234
|
self,
|
235
235
|
mooncake_session_id: str,
|
236
|
-
prefill_kv_indices: npt.NDArray[np.
|
236
|
+
prefill_kv_indices: npt.NDArray[np.int32],
|
237
237
|
dst_kv_ptrs: list[int],
|
238
|
-
dst_kv_indices: npt.NDArray[np.
|
238
|
+
dst_kv_indices: npt.NDArray[np.int32],
|
239
239
|
executor: concurrent.futures.ThreadPoolExecutor,
|
240
240
|
):
|
241
241
|
# Group by indices
|
@@ -545,7 +545,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
545
545
|
def add_transfer_request(
|
546
546
|
self,
|
547
547
|
bootstrap_room: int,
|
548
|
-
kv_indices: npt.NDArray[np.
|
548
|
+
kv_indices: npt.NDArray[np.int32],
|
549
549
|
index_slice: slice,
|
550
550
|
is_last: bool,
|
551
551
|
aux_index: Optional[int] = None,
|
@@ -677,7 +677,12 @@ class MooncakeKVManager(BaseKVManager):
|
|
677
677
|
class MooncakeKVSender(BaseKVSender):
|
678
678
|
|
679
679
|
def __init__(
|
680
|
-
self,
|
680
|
+
self,
|
681
|
+
mgr: MooncakeKVManager,
|
682
|
+
bootstrap_addr: str,
|
683
|
+
bootstrap_room: int,
|
684
|
+
dest_tp_ranks: List[int],
|
685
|
+
pp_rank: int,
|
681
686
|
):
|
682
687
|
self.kv_mgr = mgr
|
683
688
|
self.bootstrap_room = bootstrap_room
|
@@ -696,7 +701,7 @@ class MooncakeKVSender(BaseKVSender):
|
|
696
701
|
|
697
702
|
def send(
|
698
703
|
self,
|
699
|
-
kv_indices: npt.NDArray[np.
|
704
|
+
kv_indices: npt.NDArray[np.int32],
|
700
705
|
):
|
701
706
|
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
|
702
707
|
self.curr_idx += len(kv_indices)
|
@@ -966,7 +971,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
966
971
|
cls._socket_locks[endpoint] = threading.Lock()
|
967
972
|
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
968
973
|
|
969
|
-
def init(self, kv_indices: npt.NDArray[np.
|
974
|
+
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
970
975
|
for bootstrap_info in self.bootstrap_infos:
|
971
976
|
self.prefill_server_url = (
|
972
977
|
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|