sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/decode.py
@@ -21,20 +21,19 @@ Life cycle of a request in the decode server
  from __future__ import annotations

  import logging
- import os
  from collections import deque
  from dataclasses import dataclass
  from http import HTTPStatus
  from typing import TYPE_CHECKING, List, Optional, Tuple, Union

- import numpy as np
  import torch
  from torch.distributed import ProcessGroup

- from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
+ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
+ from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
  from sglang.srt.disaggregation.utils import (
+ FAKE_BOOTSTRAP_HOST,
  DisaggregationMode,
- FakeBootstrapHost,
  KVClassType,
  MetadataBuffers,
  ReqToMetadataIdxAllocator,
@@ -46,10 +45,12 @@ from sglang.srt.disaggregation.utils import (
  prepare_abort,
  )
  from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
- from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+ from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
  from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+ from sglang.srt.utils import require_mlp_sync

  logger = logging.getLogger(__name__)

@@ -86,7 +87,7 @@ class DecodeReqToTokenPool:
  self.max_context_len = max_context_len
  self.device = device
  self.pre_alloc_size = pre_alloc_size
- with memory_saver_adapter.region():
+ with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
  self.req_to_token = torch.zeros(
  (size + pre_alloc_size, max_context_len),
  dtype=torch.int32,
@@ -135,7 +136,7 @@ class DecodePreallocQueue:
  def __init__(
  self,
  req_to_token_pool: ReqToTokenPool,
- token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+ token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
  draft_token_to_kv_pool: Optional[KVCache],
  req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
  metadata_buffers: MetadataBuffers,
@@ -145,7 +146,12 @@ class DecodePreallocQueue:
  gloo_group: ProcessGroup,
  tp_rank: int,
  tp_size: int,
+ dp_size: int,
+ gpu_id: int,
  bootstrap_port: int,
+ max_total_num_tokens: int,
+ prefill_pp_size: int,
+ num_reserved_decode_tokens: int,
  transfer_backend: TransferBackend,
  ):
  self.req_to_token_pool = req_to_token_pool
@@ -161,25 +167,33 @@ class DecodePreallocQueue:
  self.gloo_group = gloo_group
  self.tp_rank = tp_rank
  self.tp_size = tp_size
+ self.dp_size = dp_size
+ self.gpu_id = gpu_id
  self.bootstrap_port = bootstrap_port
-
- self.num_reserved_decode_tokens = int(
- os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
- )
-
+ self.max_total_num_tokens = max_total_num_tokens
+ self.prefill_pp_size = prefill_pp_size
+ self.num_reserved_decode_tokens = num_reserved_decode_tokens
+ self.transfer_backend = transfer_backend
  # Queue for requests pending pre-allocation
  self.queue: List[DecodeRequest] = []
- self.transfer_backend = transfer_backend
+ self.retracted_queue: List[Req] = []
+ self.prefill_pp_size = prefill_pp_size
  self.kv_manager = self._init_kv_manager()

  def _init_kv_manager(self) -> BaseKVManager:
- kv_args = KVArgs()
- kv_args.engine_rank = self.tp_rank
+ kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
+ kv_args = kv_args_class()
+
+ attn_tp_size = self.tp_size // self.dp_size
+ kv_args.engine_rank = self.tp_rank % (attn_tp_size)
+ kv_args.decode_tp_size = attn_tp_size
+ kv_args.prefill_pp_size = self.prefill_pp_size
  kv_data_ptrs, kv_data_lens, kv_item_lens = (
  self.token_to_kv_pool.get_contiguous_buf_infos()
  )
-
  if self.draft_token_to_kv_pool is not None:
+ # We should also transfer draft model kv cache. The indices are
+ # always shared with a target model.
  draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
  self.draft_token_to_kv_pool.get_contiguous_buf_infos()
  )
@@ -194,6 +208,7 @@ class DecodePreallocQueue:
  kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
  self.metadata_buffers.get_buf_infos()
  )
+
  kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
  kv_args.gpu_id = self.scheduler.gpu_id
  kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -205,27 +220,84 @@ class DecodePreallocQueue:
  )
  return kv_manager

- def add(self, req: Req) -> None:
+ def add(self, req: Req, is_retracted: bool = False) -> None:
  """Add a request to the pending queue."""
- if req.bootstrap_host == FakeBootstrapHost:
- # Fake transfer for warmup reqs
- kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
+ if self._check_if_req_exceed_kv_capacity(req):
+ return
+
+ if is_retracted:
+ self.retracted_queue.append(req)
  else:
- kv_receiver_class = get_kv_class(
- self.transfer_backend, KVClassType.RECEIVER
+ if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
+ kv_receiver_class = get_kv_class(
+ TransferBackend.FAKE, KVClassType.RECEIVER
+ )
+ else:
+ kv_receiver_class = get_kv_class(
+ self.transfer_backend, KVClassType.RECEIVER
+ )
+
+ kv_receiver = kv_receiver_class(
+ mgr=self.kv_manager,
+ bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
+ bootstrap_room=req.bootstrap_room,
+ data_parallel_rank=req.data_parallel_rank,
  )
- kv_receiver = kv_receiver_class(
- mgr=self.kv_manager,
- bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
- bootstrap_room=req.bootstrap_room,
- data_parallel_rank=req.data_parallel_rank,
- )
- self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))

- def extend(self, reqs: List[Req]) -> None:
+ self.queue.append(
+ DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
+ )
+
+ def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
+ if len(req.origin_input_ids) > self.max_total_num_tokens:
+ message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
+ logger.error(message)
+ prepare_abort(req, message)
+ self.scheduler.stream_output([req], req.return_logprob)
+ return True
+ return False
+
+ def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
  """Add a request to the pending queue."""
  for req in reqs:
- self.add(req)
+ self.add(req, is_retracted=is_retracted)
+
+ def resume_retracted_reqs(self) -> List[Req]:
+ # TODO refactor the scheduling part, reuse with the unified engine logic as much as possible
+
+ # allocate memory
+ resumed_reqs = []
+ indices_to_remove = set()
+ allocatable_tokens = self._allocatable_tokens(count_retracted=False)
+
+ for i, req in enumerate(self.retracted_queue):
+ if self.req_to_token_pool.available_size() <= 0:
+ break
+
+ required_tokens_for_request = (
+ len(req.origin_input_ids)
+ + len(req.output_ids)
+ + self.num_reserved_decode_tokens
+ )
+ if required_tokens_for_request > allocatable_tokens:
+ break
+
+ resumed_reqs.append(req)
+ indices_to_remove.add(i)
+ req.is_retracted = False
+ self._pre_alloc(req)
+ allocatable_tokens -= required_tokens_for_request
+
+ # load from cpu, release the cpu copy
+ req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)
+
+ self.retracted_queue = [
+ entry
+ for i, entry in enumerate(self.retracted_queue)
+ if i not in indices_to_remove
+ ]
+
+ return resumed_reqs

  def _update_handshake_waiters(self) -> None:
  if not self.queue:
@@ -255,6 +327,8 @@ class DecodePreallocQueue:
  error_message,
  status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
  )
+ else:
+ raise ValueError(f"Unexpected poll case: {poll}")

  def pop_preallocated(self) -> List[DecodeRequest]:
  """Pop the preallocated requests from the pending queue (FIFO)."""
@@ -262,8 +336,16 @@ class DecodePreallocQueue:

  preallocated_reqs = []
  indices_to_remove = set()
- allocatable_tokens = self._allocatable_tokens()

+ # We need to make sure that the sum of inflight tokens and allocatable tokens is greater than maximum input+output length of each inflight request
+ # Otherwise it is possible for one request running decode out of memory, while all other requests are in the transfer queue that cannot be retracted.
+ retractable_tokens = sum(
+ len(r.origin_input_ids) + len(r.output_ids)
+ for r in self.scheduler.running_batch.reqs
+ )
+ allocatable_tokens = self._allocatable_tokens(
+ retractable_tokens=retractable_tokens, count_retracted=True
+ )
  # First, remove all failed requests from the queue
  for i, decode_req in enumerate(self.queue):
  if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
@@ -272,6 +354,7 @@ class DecodePreallocQueue:
  )
  indices_to_remove.add(i)

+ # Then, preallocate the remaining requests if possible
  for i, decode_req in enumerate(self.queue):
  if i in indices_to_remove:
  continue
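The two new comments above capture the admission invariant for the prealloc queue: a request is only admitted if it can still be served in the worst case where it emits its full max_new_tokens, counting tokens that could be reclaimed by retracting every running request. A minimal standalone sketch of that check (hypothetical names and numbers, not the sglang API):

# Simplified admission check for the decode prealloc queue (illustrative only;
# names and numbers are hypothetical, not the sglang API).

def can_preallocate(
    input_len: int,
    max_new_tokens: int,
    num_reserved_decode_tokens: int,
    allocatable_tokens: int,
    retractable_tokens: int,  # sum of input+output tokens of running requests
) -> bool:
    # Immediate need: KV slots for the prompt plus a small decode reserve.
    required_now = input_len + num_reserved_decode_tokens
    # Worst-case need: the full prompt plus max_new_tokens, minus what could be
    # reclaimed by retracting all currently running requests.
    worst_case = input_len + max_new_tokens - retractable_tokens
    return max(required_now, worst_case) <= allocatable_tokens


if __name__ == "__main__":
    # A 4k-token prompt with up to 2k new tokens, a 512-token reserve,
    # 5k allocatable tokens, and 1.5k retractable tokens: admitted.
    print(can_preallocate(4_000, 2_000, 512, 5_000, 1_500))  # True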
@@ -285,10 +368,23 @@ class DecodePreallocQueue:
  if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
  break

+ # Memory estimation: don't add if the projected memory cannot be met
+ # TODO: add new_token ratio
+ origin_input_len = len(decode_req.req.origin_input_ids)
  required_tokens_for_request = (
- len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
+ origin_input_len + self.num_reserved_decode_tokens
  )

+ if (
+ max(
+ required_tokens_for_request,
+ origin_input_len
+ + decode_req.req.sampling_params.max_new_tokens
+ - retractable_tokens,
+ )
+ > allocatable_tokens
+ ):
+ break
  if required_tokens_for_request > allocatable_tokens:
  break

@@ -301,7 +397,6 @@ class DecodePreallocQueue:
  ]
  .cpu()
  .numpy()
- .astype(np.int64)
  )

  decode_req.metadata_buffer_index = (
@@ -321,15 +416,35 @@ class DecodePreallocQueue:

  return preallocated_reqs

- def _allocatable_tokens(self) -> int:
- allocatable_tokens = (
- self.token_to_kv_pool_allocator.available_size()
- - self.num_reserved_decode_tokens
+ def _allocatable_tokens(
+ self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
+ ) -> int:
+ need_space_for_single_req = (
+ max(
+ [
+ x.sampling_params.max_new_tokens
+ + len(x.origin_input_ids)
+ - retractable_tokens
+ for x in self.scheduler.running_batch.reqs
+ ]
+ )
+ if retractable_tokens is not None
+ and len(self.scheduler.running_batch.reqs) > 0
+ else 0
+ )
+
+ available_size = self.token_to_kv_pool_allocator.available_size()
+
+ allocatable_tokens = available_size - max(
+ # preserve some space for future decode
+ self.num_reserved_decode_tokens
  * (
  len(self.scheduler.running_batch.reqs)
  + len(self.transfer_queue.queue)
  + len(self.scheduler.waiting_queue)
- )
+ ),
+ # make sure each request can finish if reach max_tokens with all other requests retracted
+ need_space_for_single_req,
  )

  # Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
@@ -342,15 +457,27 @@ class DecodePreallocQueue:
  self.scheduler.last_batch.reqs
  )

+ if count_retracted:
+ allocatable_tokens -= sum(
+ [
+ len(req.origin_input_ids)
+ + len(req.output_ids)
+ + self.num_reserved_decode_tokens
+ for req in self.retracted_queue
+ ]
+ )
  return allocatable_tokens

  def _pre_alloc(self, req: Req) -> torch.Tensor:
  """Pre-allocate the memory for req_to_token and token_kv_pool"""
  req_pool_indices = self.req_to_token_pool.alloc(1)

- assert req_pool_indices is not None
+ assert (
+ req_pool_indices is not None
+ ), "req_pool_indices is full! There is a bug in memory estimation."

  req.req_pool_idx = req_pool_indices[0]
+
  if self.token_to_kv_pool_allocator.page_size == 1:
  kv_loc = self.token_to_kv_pool_allocator.alloc(
  len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
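The reworked _allocatable_tokens above reserves whichever is larger: a per-request decode reserve for every in-flight request, or enough headroom for the single largest request to finish with all other requests retracted, and then discounts tokens still owed to retracted requests. A simplified, self-contained sketch of that estimate (illustrative names and numbers, not the actual implementation):

# Illustrative sketch of the allocatable-token estimate used by the decode
# prealloc queue (simplified; not the actual sglang implementation).

def allocatable_tokens(
    available_kv_tokens: int,          # free slots in the KV cache allocator
    num_reserved_decode_tokens: int,   # reserve per in-flight request
    num_inflight_reqs: int,            # running + transferring + waiting requests
    largest_req_remaining_need: int,   # max(input_len + max_new_tokens) - retractable tokens
    retracted_queue_need: int,         # tokens owed to retracted requests
) -> int:
    # Keep back whichever is larger: a per-request decode reserve for every
    # in-flight request, or enough space for the largest single request to
    # finish even if every other running request is retracted.
    reserve = max(
        num_reserved_decode_tokens * num_inflight_reqs,
        largest_req_remaining_need,
    )
    # Tokens needed to resume already-retracted requests are not allocatable.
    return available_kv_tokens - reserve - retracted_queue_need


if __name__ == "__main__":
    # Hypothetical numbers: 100k free KV slots, a 512-token reserve, 8 in-flight
    # requests, one request that may still need 6k tokens beyond what is
    # retractable, and 3k tokens owed to retracted requests.
    print(allocatable_tokens(100_000, 512, 8, 6_000, 3_000))  # -> 91000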
@@ -375,7 +502,10 @@ class DecodePreallocQueue:
  ),
  extend_num_tokens=num_tokens,
  )
- assert kv_loc is not None
+
+ assert (
+ kv_loc is not None
+ ), "KV cache is full! There is a bug in memory estimation."

  self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)

@@ -395,6 +525,7 @@ class DecodeTransferQueue:
  self,
  gloo_group: ProcessGroup,
  req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
+ tp_rank: int,
  metadata_buffers: MetadataBuffers,
  scheduler: Scheduler,
  tree_cache: BasePrefixCache,
@@ -402,9 +533,11 @@
  self.queue: List[DecodeRequest] = []
  self.gloo_group = gloo_group
  self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
+ self.tp_rank = tp_rank
  self.metadata_buffers = metadata_buffers
  self.scheduler = scheduler
  self.tree_cache = tree_cache
+ self.spec_algorithm = scheduler.spec_algorithm

  def add(self, decode_req: DecodeRequest) -> None:
  self.queue.append(decode_req)
@@ -412,10 +545,9 @@
  def extend(self, decode_reqs: List[DecodeRequest]) -> None:
  self.queue.extend(decode_reqs)

- def pop_transferred(self) -> List[DecodeRequest]:
+ def pop_transferred(self) -> List[Req]:
  if not self.queue:
  return []
-
  polls = poll_and_all_reduce(
  [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
  )
@@ -424,7 +556,7 @@
  indices_to_remove = set()
  for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
  if poll == KVPoll.Failed:
- error_message = f"Decode transfer failed for request rank={self.scheduler.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+ error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
  try:
  decode_req.kv_receiver.failure_exception()
  except Exception as e:
@@ -447,6 +579,7 @@
  idx = decode_req.metadata_buffer_index
  (
  output_id,
+ output_hidden_states,
  output_token_logprobs_val,
  output_token_logprobs_idx,
  output_top_logprobs_val,
@@ -454,7 +587,8 @@
  ) = self.metadata_buffers.get_buf(idx)

  decode_req.req.output_ids.append(output_id[0].item())
-
+ if not self.spec_algorithm.is_none():
+ decode_req.req.hidden_states_tensor = output_hidden_states
  if decode_req.req.return_logprob:
  decode_req.req.output_token_logprobs_val.append(
  output_token_logprobs_val[0].item()
@@ -499,15 +633,6 @@

  class SchedulerDisaggregationDecodeMixin:

- def _prepare_idle_batch_and_run(self, batch, delay_process=False):
- batch, _ = self.prepare_dp_attn_batch(batch)
- result = None
- if batch:
- result = self.run_batch(batch)
- if not delay_process:
- self.process_batch_result(batch, result)
- return batch, result
-
  @torch.no_grad()
  def event_loop_normal_disagg_decode(self: Scheduler):
  """A normal scheduler loop for decode worker in disaggregation mode."""
@@ -520,10 +645,7 @@ class SchedulerDisaggregationDecodeMixin:
  batch = self.get_next_disagg_decode_batch_to_run()
  self.cur_batch = batch

- prepare_dp_attn_flag = (
- self.server_args.enable_dp_attention
- or self.server_args.enable_sp_layernorm
- )
+ prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

  if batch:
  # Generate fake extend output.
@@ -532,24 +654,26 @@ class SchedulerDisaggregationDecodeMixin:
  self.stream_output(
  batch.reqs, any(req.return_logprob for req in batch.reqs)
  )
- if prepare_dp_attn_flag:
+ if prepare_mlp_sync_flag:
  self._prepare_idle_batch_and_run(None)
  else:
- if prepare_dp_attn_flag:
- self.prepare_dp_attn_batch(batch)
+ if prepare_mlp_sync_flag:
+ self.prepare_mlp_sync_batch(batch)
  result = self.run_batch(batch)
  self.process_batch_result(batch, result)
- elif prepare_dp_attn_flag:
+ elif prepare_mlp_sync_flag:
  batch, _ = self._prepare_idle_batch_and_run(None)

  if batch is None and (
- len(self.disagg_decode_transfer_queue.queue)
+ len(self.waiting_queue)
+ + len(self.disagg_decode_transfer_queue.queue)
  + len(self.disagg_decode_prealloc_queue.queue)
  == 0
  ):
  # When the server is idle, do self-check and re-init some states
  self.check_memory()
  self.new_token_ratio = self.init_new_token_ratio
+ self.maybe_sleep_on_idle()

  self.last_batch = batch

@@ -568,10 +692,7 @@ class SchedulerDisaggregationDecodeMixin:
  self.cur_batch = batch
  last_batch_in_queue = False

- prepare_dp_attn_flag = (
- self.server_args.enable_dp_attention
- or self.server_args.enable_sp_layernorm
- )
+ prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

  if batch:
  # Generate fake extend output.
@@ -580,7 +701,7 @@ class SchedulerDisaggregationDecodeMixin:
  self.stream_output(
  batch.reqs, any(req.return_logprob for req in batch.reqs)
  )
- if prepare_dp_attn_flag:
+ if prepare_mlp_sync_flag:
  batch_, result = self._prepare_idle_batch_and_run(
  None, delay_process=True
  )
@@ -588,8 +709,8 @@ class SchedulerDisaggregationDecodeMixin:
  result_queue.append((batch_.copy(), result))
  last_batch_in_queue = True
  else:
- if prepare_dp_attn_flag:
- self.prepare_dp_attn_batch(batch)
+ if prepare_mlp_sync_flag:
+ self.prepare_mlp_sync_batch(batch)
  result = self.run_batch(batch)
  result_queue.append((batch.copy(), result))

@@ -604,7 +725,7 @@ class SchedulerDisaggregationDecodeMixin:
  self.set_next_batch_sampling_info_done(tmp_batch)
  last_batch_in_queue = True

- elif prepare_dp_attn_flag:
+ elif prepare_mlp_sync_flag:
  batch, result = self._prepare_idle_batch_and_run(
  None, delay_process=True
  )
@@ -621,17 +742,28 @@ class SchedulerDisaggregationDecodeMixin:
  self.process_batch_result(tmp_batch, tmp_result)

  if batch is None and (
- len(self.disagg_decode_transfer_queue.queue)
+ len(self.waiting_queue)
+ + len(self.disagg_decode_transfer_queue.queue)
  + len(self.disagg_decode_prealloc_queue.queue)
  == 0
  ):
  # When the server is idle, do self-check and re-init some states
  self.check_memory()
  self.new_token_ratio = self.init_new_token_ratio
+ self.maybe_sleep_on_idle()

  self.last_batch = batch
  self.last_batch_in_queue = last_batch_in_queue

+ def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
+ batch, _ = self.prepare_mlp_sync_batch(batch)
+ result = None
+ if batch:
+ result = self.run_batch(batch)
+ if not delay_process:
+ self.process_batch_result(batch, result)
+ return batch, result
+
  def get_next_disagg_decode_batch_to_run(
  self: Scheduler,
  ) -> Optional[Tuple[ScheduleBatch, bool]]:
@@ -714,6 +846,13 @@ class SchedulerDisaggregationDecodeMixin:
  return new_batch

  def process_decode_queue(self: Scheduler):
+ # try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps
+ resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
+ self.waiting_queue.extend(resumed_reqs)
+ if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
+ # if there are still retracted requests, we do not allocate new requests
+ return
+
  req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
  self.disagg_decode_transfer_queue.extend(req_conns)
  alloc_reqs = (
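The new process_decode_queue ordering resumes retracted requests before admitting anything new, and skips preallocation entirely while retracted requests are still pending. A small standalone model of that gating (simplified; hypothetical class and method names, not the actual Scheduler API):

# Illustrative sketch of the decode-queue processing order added above:
# retracted requests are resumed before any new request is preallocated,
# and preallocation is skipped while retracted requests are still pending.
# (Simplified standalone model; not the actual Scheduler API.)

from dataclasses import dataclass, field
from typing import List


@dataclass
class DecodeQueues:
    retracted: List[str] = field(default_factory=list)
    prealloc: List[str] = field(default_factory=list)
    waiting: List[str] = field(default_factory=list)

    def process_decode_queue(self, resumable: int) -> None:
        # 1) Try to resume retracted requests first.
        resumed, self.retracted = self.retracted[:resumable], self.retracted[resumable:]
        self.waiting.extend(resumed)
        # 2) If some retracted requests still cannot fit, do not admit new ones.
        if self.retracted:
            return
        # 3) Otherwise, move preallocated requests forward as usual.
        self.waiting.extend(self.prealloc)
        self.prealloc.clear()


if __name__ == "__main__":
    q = DecodeQueues(retracted=["r1", "r2"], prealloc=["p1"])
    q.process_decode_queue(resumable=1)    # only r1 fits this round
    print(q.waiting, q.retracted, q.prealloc)  # ['r1'] ['r2'] ['p1']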
sglang/srt/disaggregation/decode_schedule_batch_mixin.py
@@ -126,15 +126,16 @@ class ScheduleBatchDisaggregationDecodeMixin:
  )
  topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)

+ hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
+ hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
+
  # local import to avoid circular import
  from sglang.srt.speculative.eagle_utils import EagleDraftInput

  spec_info = EagleDraftInput(
  topk_p=topk_p,
  topk_index=topk_index,
- hidden_states=torch.ones(
- (b, model_config.hidden_size), device=self.device
- ),
+ hidden_states=hidden_states,
  verified_id=self.output_ids,
  )
  spec_info.prepare_for_extend(self)
sglang/srt/disaggregation/fake/__init__.py
@@ -1 +1 @@
- from .conn import FakeKVReceiver, FakeKVSender
+ from sglang.srt.disaggregation.fake.conn import FakeKVReceiver, FakeKVSender
sglang/srt/disaggregation/fake/conn.py
@@ -1,5 +1,5 @@
  import logging
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import List, Optional

  import numpy as np
  import numpy.typing as npt
@@ -8,7 +8,6 @@ from sglang.srt.disaggregation.base.conn import (
  BaseKVManager,
  BaseKVReceiver,
  BaseKVSender,
- KVArgs,
  KVPoll,
  )

@@ -17,7 +16,14 @@ logger = logging.getLogger(__name__)

  # For warmup reqs, we don't kv transfer, we use the fake sender and receiver
  class FakeKVSender(BaseKVSender):
- def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
+ def __init__(
+ self,
+ mgr: BaseKVManager,
+ bootstrap_addr: str,
+ bootstrap_room: int,
+ dest_tp_ranks: List[int],
+ pp_rank: int,
+ ):
  self.has_sent = False

  def poll(self) -> KVPoll:
@@ -26,7 +32,7 @@ class FakeKVSender(BaseKVSender):
  return KVPoll.WaitingForInput
  else:
  # Assume transfer completed instantly
- logger.info("FakeKVSender poll success")
+ logger.debug("FakeKVSender poll success")
  return KVPoll.Success

  def init(
@@ -34,17 +40,17 @@
  kv_indices: list[int],
  aux_index: Optional[int] = None,
  ):
- logger.info(
+ logger.debug(
  f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
  )
  pass

  def send(
  self,
- kv_indices: npt.NDArray[np.int64],
+ kv_indices: npt.NDArray[np.int32],
  ):
  self.has_sent = True
- logger.info(f"FakeKVSender send with kv_indices: {kv_indices}")
+ logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")

  def failure_exception(self):
  raise Exception("Fake KVSender Exception")
@@ -66,12 +72,12 @@ class FakeKVReceiver(BaseKVReceiver):
  return KVPoll.WaitingForInput
  else:
  # Assume transfer completed instantly
- logger.info("FakeKVReceiver poll success")
+ logger.debug("FakeKVReceiver poll success")
  return KVPoll.Success

  def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
  self.has_init = True
- logger.info(
+ logger.debug(
  f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
  )
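As the comment in this file notes, warmup requests skip real KV transfer by substituting the fake sender and receiver. A minimal sketch of that selection, modeled on the decode.py change earlier in this diff (the sentinel value and helper below are placeholders; the real code dispatches through get_kv_class):

# Illustrative selection of a fake KV receiver for warmup requests
# (simplified; the real decode server dispatches via get_kv_class()).

FAKE_BOOTSTRAP_HOST = "fake-bootstrap-host"  # placeholder sentinel, for illustration only


class FakeKVReceiver:
    """Pretends the KV transfer finished immediately; never touches the network."""

    def poll(self) -> str:
        return "Success"


class RealKVReceiver:
    """Stands in for a backend-specific receiver that performs a real transfer."""

    def poll(self) -> str:
        return "WaitingForInput"


def make_receiver(bootstrap_host: str):
    # Warmup requests carry the fake bootstrap host, so they skip real transfer.
    if bootstrap_host == FAKE_BOOTSTRAP_HOST:
        return FakeKVReceiver()
    return RealKVReceiver()


if __name__ == "__main__":
    print(type(make_receiver(FAKE_BOOTSTRAP_HOST)).__name__)  # FakeKVReceiver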