sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/lang/interpreter.py +40 -1
  5. sglang/lang/ir.py +27 -0
  6. sglang/math_utils.py +8 -0
  7. sglang/srt/configs/model_config.py +6 -0
  8. sglang/srt/conversation.py +6 -0
  9. sglang/srt/disaggregation/base/__init__.py +1 -1
  10. sglang/srt/disaggregation/base/conn.py +25 -11
  11. sglang/srt/disaggregation/common/__init__.py +5 -1
  12. sglang/srt/disaggregation/common/utils.py +42 -0
  13. sglang/srt/disaggregation/decode.py +196 -51
  14. sglang/srt/disaggregation/fake/__init__.py +1 -1
  15. sglang/srt/disaggregation/fake/conn.py +15 -9
  16. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +18 -13
  18. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  19. sglang/srt/disaggregation/nixl/conn.py +17 -12
  20. sglang/srt/disaggregation/prefill.py +128 -43
  21. sglang/srt/disaggregation/utils.py +127 -123
  22. sglang/srt/entrypoints/engine.py +15 -1
  23. sglang/srt/entrypoints/http_server.py +13 -2
  24. sglang/srt/eplb_simulator/__init__.py +1 -0
  25. sglang/srt/eplb_simulator/reader.py +51 -0
  26. sglang/srt/layers/activation.py +19 -0
  27. sglang/srt/layers/attention/aiter_backend.py +15 -2
  28. sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
  29. sglang/srt/layers/attention/flashattention_backend.py +53 -64
  30. sglang/srt/layers/attention/flashinfer_backend.py +1 -2
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
  32. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  33. sglang/srt/layers/attention/triton_backend.py +119 -119
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  35. sglang/srt/layers/attention/vision.py +51 -24
  36. sglang/srt/layers/communicator.py +23 -5
  37. sglang/srt/layers/linear.py +0 -4
  38. sglang/srt/layers/logits_processor.py +0 -12
  39. sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
  40. sglang/srt/layers/moe/ep_moe/layer.py +42 -32
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
  43. sglang/srt/layers/moe/topk.py +16 -8
  44. sglang/srt/layers/pooler.py +56 -0
  45. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  46. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  47. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  49. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  50. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  51. sglang/srt/layers/radix_attention.py +2 -3
  52. sglang/srt/lora/lora_manager.py +79 -34
  53. sglang/srt/lora/mem_pool.py +4 -5
  54. sglang/srt/managers/cache_controller.py +2 -1
  55. sglang/srt/managers/io_struct.py +28 -4
  56. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  57. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  58. sglang/srt/managers/schedule_batch.py +39 -6
  59. sglang/srt/managers/scheduler.py +73 -17
  60. sglang/srt/managers/tokenizer_manager.py +29 -2
  61. sglang/srt/mem_cache/chunk_cache.py +1 -0
  62. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  63. sglang/srt/mem_cache/memory_pool.py +111 -407
  64. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  65. sglang/srt/mem_cache/radix_cache.py +36 -12
  66. sglang/srt/model_executor/cuda_graph_runner.py +122 -55
  67. sglang/srt/model_executor/forward_batch_info.py +14 -5
  68. sglang/srt/model_executor/model_runner.py +6 -6
  69. sglang/srt/model_loader/loader.py +8 -1
  70. sglang/srt/models/bert.py +113 -13
  71. sglang/srt/models/deepseek_v2.py +113 -155
  72. sglang/srt/models/internvl.py +46 -102
  73. sglang/srt/models/roberta.py +117 -9
  74. sglang/srt/models/vila.py +305 -0
  75. sglang/srt/openai_api/adapter.py +162 -4
  76. sglang/srt/openai_api/protocol.py +37 -1
  77. sglang/srt/sampling/sampling_batch_info.py +24 -0
  78. sglang/srt/sampling/sampling_params.py +2 -0
  79. sglang/srt/server_args.py +318 -233
  80. sglang/srt/speculative/build_eagle_tree.py +1 -1
  81. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
  82. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
  83. sglang/srt/speculative/eagle_utils.py +389 -109
  84. sglang/srt/speculative/eagle_worker.py +134 -43
  85. sglang/srt/two_batch_overlap.py +4 -2
  86. sglang/srt/utils.py +58 -0
  87. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  88. sglang/test/runners.py +38 -3
  89. sglang/test/test_block_fp8.py +1 -0
  90. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  91. sglang/test/test_block_fp8_ep.py +1 -0
  92. sglang/test/test_utils.py +3 -1
  93. sglang/utils.py +9 -0
  94. sglang/version.py +1 -1
  95. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
  96. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
  97. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/decode.py

@@ -31,10 +31,10 @@ import numpy as np
 import torch
 from torch.distributed import ProcessGroup
 
-from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
+from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
 from sglang.srt.disaggregation.utils import (
+    FAKE_BOOTSTRAP_HOST,
     DisaggregationMode,
-    FakeBootstrapHost,
     KVClassType,
     MetadataBuffers,
     ReqToMetadataIdxAllocator,
@@ -47,7 +47,11 @@ from sglang.srt.disaggregation.utils import (
 )
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import (
+    KVCache,
+    ReqToTokenPool,
+    TokenToKVPoolAllocator,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
@@ -145,7 +149,12 @@ class DecodePreallocQueue:
         gloo_group: ProcessGroup,
         tp_rank: int,
         tp_size: int,
+        dp_size: int,
+        gpu_id: int,
         bootstrap_port: int,
+        max_total_num_tokens: int,
+        prefill_pp_size: int,
+        num_reserved_decode_tokens: int,
         transfer_backend: TransferBackend,
     ):
         self.req_to_token_pool = req_to_token_pool
@@ -161,25 +170,33 @@ class DecodePreallocQueue:
         self.gloo_group = gloo_group
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.dp_size = dp_size
+        self.gpu_id = gpu_id
         self.bootstrap_port = bootstrap_port
-
-        self.num_reserved_decode_tokens = int(
-            os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
-        )
-
+        self.max_total_num_tokens = max_total_num_tokens
+        self.prefill_pp_size = prefill_pp_size
+        self.num_reserved_decode_tokens = num_reserved_decode_tokens
+        self.transfer_backend = transfer_backend
         # Queue for requests pending pre-allocation
         self.queue: List[DecodeRequest] = []
-        self.transfer_backend = transfer_backend
+        self.retracted_queue: List[Req] = []
+        self.prefill_pp_size = prefill_pp_size
         self.kv_manager = self._init_kv_manager()
 
     def _init_kv_manager(self) -> BaseKVManager:
-        kv_args = KVArgs()
-        kv_args.engine_rank = self.tp_rank
+        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
+        kv_args = kv_args_class()
+
+        attn_tp_size = self.tp_size // self.dp_size
+        kv_args.engine_rank = self.tp_rank % (attn_tp_size)
+        kv_args.decode_tp_size = attn_tp_size
+        kv_args.prefill_pp_size = self.prefill_pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
             self.token_to_kv_pool.get_contiguous_buf_infos()
         )
-
         if self.draft_token_to_kv_pool is not None:
+            # We should also transfer draft model kv cache. The indices are
+            # always shared with a target model.
             draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
                 self.draft_token_to_kv_pool.get_contiguous_buf_infos()
             )
@@ -194,6 +211,7 @@ class DecodePreallocQueue:
         kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
             self.metadata_buffers.get_buf_infos()
         )
+
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -205,27 +223,84 @@ class DecodePreallocQueue:
         )
         return kv_manager
 
-    def add(self, req: Req) -> None:
+    def add(self, req: Req, is_retracted: bool = False) -> None:
         """Add a request to the pending queue."""
-        if req.bootstrap_host == FakeBootstrapHost:
-            # Fake transfer for warmup reqs
-            kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
+        if self._check_if_req_exceed_kv_capacity(req):
+            return
+
+        if is_retracted:
+            self.retracted_queue.append(req)
         else:
-            kv_receiver_class = get_kv_class(
-                self.transfer_backend, KVClassType.RECEIVER
+            if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
+                kv_receiver_class = get_kv_class(
+                    TransferBackend.FAKE, KVClassType.RECEIVER
+                )
+            else:
+                kv_receiver_class = get_kv_class(
+                    self.transfer_backend, KVClassType.RECEIVER
+                )
+
+            kv_receiver = kv_receiver_class(
+                mgr=self.kv_manager,
+                bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
+                bootstrap_room=req.bootstrap_room,
+                data_parallel_rank=req.data_parallel_rank,
+            )
+
+            self.queue.append(
+                DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
             )
-        kv_receiver = kv_receiver_class(
-            mgr=self.kv_manager,
-            bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
-            bootstrap_room=req.bootstrap_room,
-            data_parallel_rank=req.data_parallel_rank,
-        )
-        self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
 
-    def extend(self, reqs: List[Req]) -> None:
+    def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
+        if len(req.origin_input_ids) > self.max_total_num_tokens:
+            message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
+            logger.error(message)
+            prepare_abort(req, message)
+            self.scheduler.stream_output([req], req.return_logprob)
+            return True
+        return False
+
+    def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
         """Add a request to the pending queue."""
         for req in reqs:
-            self.add(req)
+            self.add(req, is_retracted=is_retracted)
+
+    def resume_retracted_reqs(self) -> List[Req]:
+        # TODO refactor the scheduling part, reuse with the unified engine logic as much as possible
+
+        # allocate memory
+        resumed_reqs = []
+        indices_to_remove = set()
+        allocatable_tokens = self._allocatable_tokens(count_retracted=False)
+
+        for i, req in enumerate(self.retracted_queue):
+            if self.req_to_token_pool.available_size() <= 0:
+                break
+
+            required_tokens_for_request = (
+                len(req.origin_input_ids)
+                + len(req.output_ids)
+                + self.num_reserved_decode_tokens
+            )
+            if required_tokens_for_request > allocatable_tokens:
+                break
+
+            resumed_reqs.append(req)
+            indices_to_remove.add(i)
+            req.is_retracted = False
+            self._pre_alloc(req)
+            allocatable_tokens -= required_tokens_for_request
+
+            # load from cpu, release the cpu copy
+            req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)
+
+        self.retracted_queue = [
+            entry
+            for i, entry in enumerate(self.retracted_queue)
+            if i not in indices_to_remove
+        ]
+
+        return resumed_reqs
 
     def _update_handshake_waiters(self) -> None:
         if not self.queue:
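The hunk above reworks `DecodePreallocQueue.add()` around a retract/resume path: requests retracted during decode are parked in `retracted_queue` (keeping their CPU KV copy) instead of opening a new KV receiver, and `resume_retracted_reqs()` re-admits them once KV space is free again. A rough usage sketch of how the scheduler is expected to drive these entry points, mirroring the `process_decode_queue` hunk later in this diff; the helper names here are illustrative stand-ins, not sglang APIs:

```python
# Illustrative sketch only: driving the new retract/resume path.
# `prealloc_queue` stands for a DecodePreallocQueue and `scheduler` for the
# decode Scheduler; only add()/extend()/resume_retracted_reqs()/pop_preallocated()
# come from the diff above, the rest is simplified glue.

def park_retracted(prealloc_queue, retracted_reqs):
    # Retracted requests skip the KV-receiver setup and wait in retracted_queue.
    prealloc_queue.extend(retracted_reqs, is_retracted=True)

def admit_work(prealloc_queue, scheduler):
    # Resumed requests go back to the waiting queue first; new prefill->decode
    # transfers are admitted only once the retracted backlog is fully drained.
    scheduler.waiting_queue.extend(prealloc_queue.resume_retracted_reqs())
    if len(prealloc_queue.retracted_queue) == 0:
        scheduler.disagg_decode_transfer_queue.extend(
            prealloc_queue.pop_preallocated()
        )
```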
@@ -255,6 +330,8 @@ class DecodePreallocQueue:
                     error_message,
                     status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                 )
+            else:
+                raise ValueError(f"Unexpected poll case: {poll}")
 
     def pop_preallocated(self) -> List[DecodeRequest]:
         """Pop the preallocated requests from the pending queue (FIFO)."""
@@ -262,8 +339,16 @@
 
         preallocated_reqs = []
         indices_to_remove = set()
-        allocatable_tokens = self._allocatable_tokens()
 
+        # We need to make sure that the sum of inflight tokens and allocatable tokens is greater than maximum input+output length of each inflight request
+        # Otherwise it is possible for one request running decode out of memory, while all other requests are in the transfer queue that cannot be retracted.
+        retractable_tokens = sum(
+            len(r.origin_input_ids) + len(r.output_ids)
+            for r in self.scheduler.running_batch.reqs
+        )
+        allocatable_tokens = self._allocatable_tokens(
+            retractable_tokens=retractable_tokens, count_retracted=True
+        )
         # First, remove all failed requests from the queue
         for i, decode_req in enumerate(self.queue):
             if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
@@ -272,6 +357,7 @@ class DecodePreallocQueue:
                 )
                 indices_to_remove.add(i)
 
+        # Then, preallocate the remaining requests if possible
         for i, decode_req in enumerate(self.queue):
             if i in indices_to_remove:
                 continue
@@ -285,10 +371,23 @@ class DecodePreallocQueue:
             if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
                 break
 
+            # Memory estimation: don't add if the projected memory cannot be met
+            # TODO: add new_token ratio
+            origin_input_len = len(decode_req.req.origin_input_ids)
             required_tokens_for_request = (
-                len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
+                origin_input_len + self.num_reserved_decode_tokens
             )
 
+            if (
+                max(
+                    required_tokens_for_request,
+                    origin_input_len
+                    + decode_req.req.sampling_params.max_new_tokens
+                    - retractable_tokens,
+                )
+                > allocatable_tokens
+            ):
+                break
             if required_tokens_for_request > allocatable_tokens:
                 break
 
@@ -301,7 +400,6 @@ class DecodePreallocQueue:
                 ]
                 .cpu()
                 .numpy()
-                .astype(np.int64)
             )
 
             decode_req.metadata_buffer_index = (
@@ -321,15 +419,35 @@
 
         return preallocated_reqs
 
-    def _allocatable_tokens(self) -> int:
-        allocatable_tokens = (
-            self.token_to_kv_pool_allocator.available_size()
-            - self.num_reserved_decode_tokens
+    def _allocatable_tokens(
+        self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
+    ) -> int:
+        need_space_for_single_req = (
+            max(
+                [
+                    x.sampling_params.max_new_tokens
+                    + len(x.origin_input_ids)
+                    - retractable_tokens
+                    for x in self.scheduler.running_batch.reqs
+                ]
+            )
+            if retractable_tokens is not None
+            and len(self.scheduler.running_batch.reqs) > 0
+            else 0
+        )
+
+        available_size = self.token_to_kv_pool_allocator.available_size()
+
+        allocatable_tokens = available_size - max(
+            # preserve some space for future decode
+            self.num_reserved_decode_tokens
             * (
                 len(self.scheduler.running_batch.reqs)
                 + len(self.transfer_queue.queue)
                 + len(self.scheduler.waiting_queue)
-            )
+            ),
+            # make sure each request can finish if reach max_tokens with all other requests retracted
+            need_space_for_single_req,
         )
 
         # Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
@@ -342,15 +460,27 @@
                 self.scheduler.last_batch.reqs
             )
 
+        if count_retracted:
+            allocatable_tokens -= sum(
+                [
+                    len(req.origin_input_ids)
+                    + len(req.output_ids)
+                    + self.num_reserved_decode_tokens
+                    for req in self.retracted_queue
+                ]
+            )
         return allocatable_tokens
 
     def _pre_alloc(self, req: Req) -> torch.Tensor:
         """Pre-allocate the memory for req_to_token and token_kv_pool"""
         req_pool_indices = self.req_to_token_pool.alloc(1)
 
-        assert req_pool_indices is not None
+        assert (
+            req_pool_indices is not None
+        ), "req_pool_indices is full! There is a bug in memory estimation."
 
         req.req_pool_idx = req_pool_indices[0]
+
         if self.token_to_kv_pool_allocator.page_size == 1:
             kv_loc = self.token_to_kv_pool_allocator.alloc(
                 len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
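The `_allocatable_tokens` rewrite above reserves the larger of two budgets before admitting new transfers: `num_reserved_decode_tokens` for every request that is running, in the transfer queue, or waiting, and enough headroom for the largest running request to reach its `max_new_tokens` even after every other running request has been retracted; with `count_retracted=True` it also subtracts what the parked retracted requests will need. A self-contained sketch of that arithmetic with made-up numbers (plain integers stand in for the allocator and `Req` objects):

```python
# Toy numbers only; integers replace token_to_kv_pool_allocator and Req objects.
num_reserved_decode_tokens = 512
available_size = 20_000        # allocator.available_size()
num_tracked_reqs = 8           # running + transfer queue + waiting queue
retractable_tokens = 6_000     # sum of input + output tokens of running requests

# Headroom for the largest running request (input 3000, max_new_tokens 4096)
# if every other running request were retracted.
need_space_for_single_req = 3_000 + 4_096 - retractable_tokens

allocatable_tokens = available_size - max(
    num_reserved_decode_tokens * num_tracked_reqs,  # room for future decode steps
    need_space_for_single_req,                      # largest request can still finish
)
print(allocatable_tokens)  # 20_000 - max(4_096, 1_096) = 15_904
```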
@@ -375,7 +505,10 @@
                 ),
                 extend_num_tokens=num_tokens,
             )
-        assert kv_loc is not None
+
+        assert (
+            kv_loc is not None
+        ), "KV cache is full! There is a bug in memory estimation."
 
         self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
 
@@ -395,6 +528,7 @@ class DecodeTransferQueue:
         self,
         gloo_group: ProcessGroup,
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
+        tp_rank: int,
         metadata_buffers: MetadataBuffers,
         scheduler: Scheduler,
         tree_cache: BasePrefixCache,
@@ -402,6 +536,7 @@ class DecodeTransferQueue:
         self.queue: List[DecodeRequest] = []
         self.gloo_group = gloo_group
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
+        self.tp_rank = tp_rank
         self.metadata_buffers = metadata_buffers
         self.scheduler = scheduler
         self.tree_cache = tree_cache
@@ -412,10 +547,9 @@ class DecodeTransferQueue:
     def extend(self, decode_reqs: List[DecodeRequest]) -> None:
         self.queue.extend(decode_reqs)
 
-    def pop_transferred(self) -> List[DecodeRequest]:
+    def pop_transferred(self) -> List[Req]:
         if not self.queue:
             return []
-
         polls = poll_and_all_reduce(
             [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
         )
@@ -424,7 +558,7 @@ class DecodeTransferQueue:
         indices_to_remove = set()
         for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
             if poll == KVPoll.Failed:
-                error_message = f"Decode transfer failed for request rank={self.scheduler.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+                error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
                 try:
                     decode_req.kv_receiver.failure_exception()
                 except Exception as e:
@@ -499,15 +633,6 @@
 
 class SchedulerDisaggregationDecodeMixin:
 
-    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
-        batch, _ = self.prepare_dp_attn_batch(batch)
-        result = None
-        if batch:
-            result = self.run_batch(batch)
-            if not delay_process:
-                self.process_batch_result(batch, result)
-        return batch, result
-
     @torch.no_grad()
     def event_loop_normal_disagg_decode(self: Scheduler):
         """A normal scheduler loop for decode worker in disaggregation mode."""
@@ -543,13 +668,15 @@ class SchedulerDisaggregationDecodeMixin:
                 batch, _ = self._prepare_idle_batch_and_run(None)
 
             if batch is None and (
-                len(self.disagg_decode_transfer_queue.queue)
+                len(self.waiting_queue)
+                + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
                 == 0
             ):
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
+                self.maybe_sleep_on_idle()
 
             self.last_batch = batch
 
@@ -621,17 +748,28 @@ class SchedulerDisaggregationDecodeMixin:
                 self.process_batch_result(tmp_batch, tmp_result)
 
             if batch is None and (
-                len(self.disagg_decode_transfer_queue.queue)
+                len(self.waiting_queue)
+                + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
                 == 0
            ):
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
+                self.maybe_sleep_on_idle()
 
             self.last_batch = batch
             self.last_batch_in_queue = last_batch_in_queue
 
+    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
+        batch, _ = self.prepare_dp_attn_batch(batch)
+        result = None
+        if batch:
+            result = self.run_batch(batch)
+            if not delay_process:
+                self.process_batch_result(batch, result)
+        return batch, result
+
     def get_next_disagg_decode_batch_to_run(
         self: Scheduler,
     ) -> Optional[Tuple[ScheduleBatch, bool]]:
@@ -714,6 +852,13 @@ class SchedulerDisaggregationDecodeMixin:
         return new_batch
 
     def process_decode_queue(self: Scheduler):
+        # try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps
+        resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
+        self.waiting_queue.extend(resumed_reqs)
+        if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
+            # if there are still retracted requests, we do not allocate new requests
+            return
+
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
sglang/srt/disaggregation/fake/__init__.py

@@ -1 +1 @@
-from .conn import FakeKVReceiver, FakeKVSender
+from sglang.srt.disaggregation.fake.conn import FakeKVReceiver, FakeKVSender
sglang/srt/disaggregation/fake/conn.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List, Optional
 
 import numpy as np
 import numpy.typing as npt
@@ -8,7 +8,6 @@ from sglang.srt.disaggregation.base.conn import (
     BaseKVManager,
     BaseKVReceiver,
     BaseKVSender,
-    KVArgs,
     KVPoll,
 )
 
@@ -17,7 +16,14 @@ logger = logging.getLogger(__name__)
 
 # For warmup reqs, we don't kv transfer, we use the fake sender and receiver
 class FakeKVSender(BaseKVSender):
-    def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
+    def __init__(
+        self,
+        mgr: BaseKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: int,
+        dest_tp_ranks: List[int],
+        pp_rank: int,
+    ):
         self.has_sent = False
 
     def poll(self) -> KVPoll:
@@ -26,7 +32,7 @@ class FakeKVSender(BaseKVSender):
             return KVPoll.WaitingForInput
         else:
             # Assume transfer completed instantly
-            logger.info("FakeKVSender poll success")
+            logger.debug("FakeKVSender poll success")
             return KVPoll.Success
 
     def init(
@@ -34,17 +40,17 @@ class FakeKVSender(BaseKVSender):
         kv_indices: list[int],
         aux_index: Optional[int] = None,
     ):
-        logger.info(
+        logger.debug(
            f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
         )
         pass
 
     def send(
         self,
-        kv_indices: npt.NDArray[np.int64],
+        kv_indices: npt.NDArray[np.int32],
     ):
         self.has_sent = True
-        logger.info(f"FakeKVSender send with kv_indices: {kv_indices}")
+        logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
 
     def failure_exception(self):
         raise Exception("Fake KVSender Exception")
@@ -66,12 +72,12 @@ class FakeKVReceiver(BaseKVReceiver):
             return KVPoll.WaitingForInput
         else:
             # Assume transfer completed instantly
-            logger.info("FakeKVReceiver poll success")
+            logger.debug("FakeKVReceiver poll success")
             return KVPoll.Success
 
     def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
         self.has_init = True
-        logger.info(
+        logger.debug(
             f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
         )
 
sglang/srt/disaggregation/mooncake/__init__.py

@@ -1,4 +1,4 @@
-from .conn import (
+from sglang.srt.disaggregation.mooncake.conn import (
     MooncakeKVBootstrapServer,
     MooncakeKVManager,
     MooncakeKVReceiver,
sglang/srt/disaggregation/mooncake/conn.py

@@ -28,12 +28,12 @@ from sglang.srt.disaggregation.base.conn import (
     KVArgs,
     KVPoll,
 )
-from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
-from sglang.srt.disaggregation.utils import (
-    DisaggregationMode,
+from sglang.srt.disaggregation.common.utils import (
     FastQueue,
     group_concurrent_contiguous,
 )
+from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_free_port,
@@ -59,7 +59,7 @@ class KVTransferError(Exception):
 @dataclasses.dataclass
 class TransferKVChunk:
     room: int
-    prefill_kv_indices: npt.NDArray[np.int64]
+    prefill_kv_indices: npt.NDArray[np.int32]
     index_slice: slice
     is_last: bool
     prefill_aux_index: Optional[int]
@@ -72,7 +72,7 @@ class TransferInfo:
     endpoint: str
     dst_port: int
     mooncake_session_id: str
-    dst_kv_indices: npt.NDArray[np.int64]
+    dst_kv_indices: npt.NDArray[np.int32]
     dst_aux_index: int
     required_dst_info_num: int
     is_dummy: bool
@@ -81,10 +81,10 @@ class TransferInfo:
     def from_zmq(cls, msg: List[bytes]):
         if msg[4] == b"" and msg[5] == b"":
             is_dummy = True
-            dst_kv_indices = np.array([], dtype=np.int64)
+            dst_kv_indices = np.array([], dtype=np.int32)
             dst_aux_index = None
         else:
-            dst_kv_indices = np.frombuffer(msg[4], dtype=np.int64)
+            dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
             dst_aux_index = int(msg[5].decode("ascii"))
             is_dummy = False
         return cls(
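This hunk, and several below, narrow the KV index arrays in mooncake/conn.py from `np.int64` to `np.int32`. The indices travel as raw bytes in a ZMQ multipart message and `from_zmq` reconstructs them with `np.frombuffer`, so both ends have to agree on the dtype. A minimal round-trip sketch of that framing, standalone NumPy only (the real messages carry more fields such as the session id and ports):

```python
# Minimal round-trip sketch of the raw-byte index framing; NumPy only.
import numpy as np

dst_kv_indices = np.array([1024, 1025, 1026, 4096], dtype=np.int32)

# Sender side: the index array is placed into the message as raw bytes.
payload = dst_kv_indices.tobytes()

# Receiver side: the dtype must match the sender or the indices are garbled.
decoded = np.frombuffer(payload, dtype=np.int32)
assert (decoded == dst_kv_indices).all()

# An empty frame corresponds to the dummy-transfer branch in the hunk above.
assert np.frombuffer(b"", dtype=np.int32).size == 0
```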
@@ -233,9 +233,9 @@ class MooncakeKVManager(BaseKVManager):
     def send_kvcache(
         self,
         mooncake_session_id: str,
-        prefill_kv_indices: npt.NDArray[np.int64],
+        prefill_kv_indices: npt.NDArray[np.int32],
         dst_kv_ptrs: list[int],
-        dst_kv_indices: npt.NDArray[np.int64],
+        dst_kv_indices: npt.NDArray[np.int32],
         executor: concurrent.futures.ThreadPoolExecutor,
     ):
         # Group by indices
@@ -545,7 +545,7 @@ class MooncakeKVManager(BaseKVManager):
     def add_transfer_request(
         self,
         bootstrap_room: int,
-        kv_indices: npt.NDArray[np.int64],
+        kv_indices: npt.NDArray[np.int32],
         index_slice: slice,
         is_last: bool,
         aux_index: Optional[int] = None,
@@ -677,7 +677,12 @@
 class MooncakeKVSender(BaseKVSender):
 
     def __init__(
-        self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int
+        self,
+        mgr: MooncakeKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: int,
+        dest_tp_ranks: List[int],
+        pp_rank: int,
     ):
         self.kv_mgr = mgr
         self.bootstrap_room = bootstrap_room
@@ -696,7 +701,7 @@ class MooncakeKVSender(BaseKVSender):
 
     def send(
         self,
-        kv_indices: npt.NDArray[np.int64],
+        kv_indices: npt.NDArray[np.int32],
     ):
         index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
         self.curr_idx += len(kv_indices)
@@ -966,7 +971,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
             cls._socket_locks[endpoint] = threading.Lock()
         return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
 
-    def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
+    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
         for bootstrap_info in self.bootstrap_infos:
             self.prefill_server_url = (
                 f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
sglang/srt/disaggregation/nixl/__init__.py

@@ -1 +1,6 @@
-from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender
+from sglang.srt.disaggregation.nixl.conn import (
+    NixlKVBootstrapServer,
+    NixlKVManager,
+    NixlKVReceiver,
+    NixlKVSender,
+)