sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0

sglang/srt/disaggregation/decode.py
@@ -24,6 +24,7 @@ import logging
  import os
  from collections import deque
  from dataclasses import dataclass
+ from http import HTTPStatus
  from typing import TYPE_CHECKING, List, Optional, Tuple

  import numpy as np
@@ -35,24 +36,25 @@ from sglang.srt.disaggregation.utils import (
  DisaggregationMode,
  FakeBootstrapHost,
  KVClassType,
+ MetadataBuffers,
  ReqToMetadataIdxAllocator,
  TransferBackend,
  get_kv_class,
+ is_mla_backend,
  kv_to_page_indices,
  poll_and_all_reduce,
+ prepare_abort,
  )
+ from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
- from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

  logger = logging.getLogger(__name__)

  if TYPE_CHECKING:
- from sglang.srt.configs.model_config import ModelConfig
- from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+ from sglang.srt.managers.schedule_batch import Req
  from sglang.srt.managers.scheduler import Scheduler
- from sglang.srt.server_args import ServerArgs


  @dataclass
@@ -72,9 +74,9 @@ class DecodePreallocQueue:
  self,
  req_to_token_pool: ReqToTokenPool,
  token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+ draft_token_to_kv_pool: Optional[KVCache],
  req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
- metadata_buffers: List[torch.Tensor],
- aux_dtype: torch.dtype,
+ metadata_buffers: MetadataBuffers,
  scheduler: Scheduler,
  transfer_queue: DecodeTransferQueue,
  tree_cache: BasePrefixCache,
@@ -87,7 +89,8 @@ class DecodePreallocQueue:
  self.req_to_token_pool = req_to_token_pool
  self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
  self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
- self.aux_dtype = aux_dtype
+ self.draft_token_to_kv_pool = draft_token_to_kv_pool
+ self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
  self.metadata_buffers = metadata_buffers
  self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
  self.scheduler = scheduler
@@ -114,24 +117,29 @@ class DecodePreallocQueue:
  self.token_to_kv_pool.get_contiguous_buf_infos()
  )

+ if self.draft_token_to_kv_pool is not None:
+ draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
+ self.draft_token_to_kv_pool.get_contiguous_buf_infos()
+ )
+ kv_data_ptrs += draft_kv_data_ptrs
+ kv_data_lens += draft_kv_data_lens
+ kv_item_lens += draft_kv_item_lens
+
  kv_args.kv_data_ptrs = kv_data_ptrs
  kv_args.kv_data_lens = kv_data_lens
  kv_args.kv_item_lens = kv_item_lens

- kv_args.aux_data_ptrs = [
- output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers
- ]
- kv_args.aux_data_lens = [
- metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
- ]
- kv_args.aux_item_lens = [
- metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
- ]
+ kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
+ self.metadata_buffers.get_buf_infos()
+ )
  kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
  kv_args.gpu_id = self.scheduler.gpu_id
  kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
  kv_manager = kv_manager_class(
- kv_args, DisaggregationMode.DECODE, self.scheduler.server_args
+ kv_args,
+ DisaggregationMode.DECODE,
+ self.scheduler.server_args,
+ self.is_mla_backend,
  )
  return kv_manager

@@ -173,7 +181,17 @@ class DecodePreallocQueue:
  elif poll == KVPoll.WaitingForInput:
  decode_req.waiting_for_input = True
  elif poll == KVPoll.Failed:
- raise Exception("Handshake failed")
+ error_message = f"Decode handshake failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+ try:
+ decode_req.kv_receiver.failure_exception()
+ except Exception as e:
+ error_message += f" with exception {e}"
+ logger.error(error_message)
+ prepare_abort(
+ decode_req.req,
+ error_message,
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+ )

  def pop_preallocated(self) -> List[DecodeRequest]:
  """Pop the preallocated requests from the pending queue (FIFO)."""
@@ -183,7 +201,18 @@ class DecodePreallocQueue:
  indices_to_remove = set()
  allocatable_tokens = self._allocatable_tokens()

+ # First, remove all failed requests from the queue
  for i, decode_req in enumerate(self.queue):
+ if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
+ self.scheduler.stream_output(
+ [decode_req.req], decode_req.req.return_logprob
+ )
+ indices_to_remove.add(i)
+
+ for i, decode_req in enumerate(self.queue):
+ if i in indices_to_remove:
+ continue
+
  if not decode_req.waiting_for_input:
  continue

@@ -303,18 +332,22 @@ class DecodeTransferQueue:
  self,
  gloo_group: ProcessGroup,
  req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
- metadata_buffers: torch.Tensor,
+ metadata_buffers: MetadataBuffers,
+ scheduler: Scheduler,
+ tree_cache: BasePrefixCache,
  ):
  self.queue: List[DecodeRequest] = []
  self.gloo_group = gloo_group
  self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
  self.metadata_buffers = metadata_buffers
+ self.scheduler = scheduler
+ self.tree_cache = tree_cache

- def add(self, req_conn: DecodeRequest) -> None:
- self.queue.append(req_conn)
+ def add(self, decode_req: DecodeRequest) -> None:
+ self.queue.append(decode_req)

- def extend(self, req_conns) -> None:
- self.queue.extend(req_conns)
+ def extend(self, decode_reqs: List[DecodeRequest]) -> None:
+ self.queue.extend(decode_reqs)

  def pop_transferred(self) -> List[DecodeRequest]:
  if not self.queue:
@@ -328,18 +361,56 @@ class DecodeTransferQueue:
  indices_to_remove = set()
  for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
  if poll == KVPoll.Failed:
- raise Exception("Transfer failed")
+ error_message = f"Decode transfer failed for request {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+ try:
+ decode_req.kv_receiver.failure_exception()
+ except Exception as e:
+ error_message += f" with exception {e}"
+ logger.error(error_message)
+ prepare_abort(
+ decode_req.req,
+ error_message,
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+ )
+ self.scheduler.stream_output(
+ [decode_req.req], decode_req.req.return_logprob
+ )
+ # unlock the kv cache or it will have memory leak
+ self.tree_cache.cache_finished_req(decode_req.req)
+ indices_to_remove.add(i)
+ continue
  elif poll == KVPoll.Success:
- # pop and push it to waiting queue
+
  idx = decode_req.metadata_buffer_index
- assert len(decode_req.req.output_ids) == 0
- output_id_buffer = self.metadata_buffers[0]
- # the last dimension is padded by the same values.
- output_id = output_id_buffer[idx][0].item()
- assert len(decode_req.req.output_ids) == 0
- assert decode_req.req.transferred_output_id is None
- decode_req.req.transferred_output_id = output_id
- transferred_reqs.append(decode_req)
+ (
+ output_id,
+ output_token_logprobs_val,
+ output_token_logprobs_idx,
+ output_top_logprobs_val,
+ output_top_logprobs_idx,
+ ) = self.metadata_buffers.get_buf(idx)
+
+ decode_req.req.output_ids.append(output_id[0].item())
+
+ if decode_req.req.return_logprob:
+ decode_req.req.output_token_logprobs_val.append(
+ output_token_logprobs_val[0].item()
+ )
+ decode_req.req.output_token_logprobs_idx.append(
+ output_token_logprobs_idx[0].item()
+ )
+ decode_req.req.output_top_logprobs_val.append(
+ output_top_logprobs_val[
+ : decode_req.req.top_logprobs_num
+ ].tolist()
+ )
+ decode_req.req.output_top_logprobs_idx.append(
+ output_top_logprobs_idx[
+ : decode_req.req.top_logprobs_num
+ ].tolist()
+ )
+
+ transferred_reqs.append(decode_req.req)
  indices_to_remove.add(i)
  elif poll in [
  KVPoll.Bootstrapping,
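
Taken together, the hunks above replace the hand-rolled list of metadata tensors with a single MetadataBuffers object: transfer registration goes through get_buf_infos(), and per-request readback goes through get_buf(idx), which now also carries logprob data. Below is a minimal sketch of what such a helper could look like, written only against the two accessors used in this diff; the real class lives in sglang/srt/disaggregation/utils.py, and the field names, shapes, and dtypes here are assumptions.

# Illustrative sketch, not part of the diff; field names, shapes, and dtypes are assumptions.
import torch


class MetadataBuffersSketch:
    def __init__(self, size: int, max_top_logprobs_num: int = 128, device: str = "cpu"):
        # One row per metadata slot; pop_transferred() reads only the leading entries of a row.
        self.output_ids = torch.zeros((size, 16), dtype=torch.int64, device=device)
        self.output_token_logprobs_val = torch.zeros((size, 16), dtype=torch.float32, device=device)
        self.output_token_logprobs_idx = torch.zeros((size, 16), dtype=torch.int64, device=device)
        self.output_top_logprobs_val = torch.zeros((size, max_top_logprobs_num), dtype=torch.float32, device=device)
        self.output_top_logprobs_idx = torch.zeros((size, max_top_logprobs_num), dtype=torch.int64, device=device)
        self._buffers = [
            self.output_ids,
            self.output_token_logprobs_val,
            self.output_token_logprobs_idx,
            self.output_top_logprobs_val,
            self.output_top_logprobs_idx,
        ]

    def get_buf_infos(self):
        # Flat pointer/length views consumed by kv_args.aux_data_ptrs / aux_data_lens / aux_item_lens.
        ptrs = [buf.data_ptr() for buf in self._buffers]
        lens = [buf.nbytes for buf in self._buffers]
        item_lens = [buf[0].nbytes for buf in self._buffers]
        return ptrs, lens, item_lens

    def get_buf(self, idx: int):
        # One row per buffer, unpacked as (output_id, logprob val/idx, top-logprob val/idx).
        return tuple(buf[idx] for buf in self._buffers)
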
@@ -362,95 +433,6 @@ class DecodeTransferQueue:
  return transferred_reqs


- class ScheduleBatchDisaggregationDecodeMixin:
-
- def prepare_for_prebuilt_extend(self: ScheduleBatch):
- """
- Prepare a prebuilt extend by populate metadata
- Adapted from .prepare_for_extend().
- """
-
- self.forward_mode = ForwardMode.EXTEND
- reqs = self.reqs
- input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
- extend_num_tokens = sum(len(ids) for ids in input_ids)
- seq_lens = []
- pre_lens = []
- req_pool_indices = []
-
- # Pre-calculate total size
- total_size = sum(req.extend_input_len for req in reqs)
- out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)
-
- # Fill the tensor in one pass
- offset = 0
- for i, req in enumerate(reqs):
- req_pool_indices.append(req.req_pool_idx)
-
- chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
- : req.extend_input_len
- ]
- assert (
- offset + req.extend_input_len <= total_size
- ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
- out_cache_loc[offset : offset + req.extend_input_len] = chunk
- offset += req.extend_input_len
-
- pre_len = len(req.prefix_indices)
- seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
- seq_lens.append(seq_len)
- if len(req.output_ids) == 0:
- assert (
- seq_len - pre_len == req.extend_input_len
- ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
-
- req.cached_tokens += pre_len - req.already_computed
- req.already_computed = seq_len
- req.is_retracted = False
- pre_lens.append(pre_len)
- req.extend_logprob_start_len = 0
-
- extend_input_logprob_token_ids = None
-
- # Set fields
- self.input_ids = torch.tensor(
- sum(input_ids, []), dtype=torch.int32, device=self.device
- )
- self.req_pool_indices = torch.tensor(
- req_pool_indices, dtype=torch.int64, device=self.device
- )
- self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
- self.out_cache_loc = out_cache_loc
- self.seq_lens_sum = sum(seq_lens)
- self.extend_num_tokens = extend_num_tokens
- self.prefix_lens = [len(r.prefix_indices) for r in reqs]
- self.extend_lens = [r.extend_input_len for r in reqs]
- self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
- self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
-
- # Build sampling info
- self.sampling_info = SamplingBatchInfo.from_schedule_batch(
- self,
- self.model_config.vocab_size,
- )
-
- def process_prebuilt_extend(
- self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
- ):
- """Assign the buffered last input id to schedule batch"""
- self.output_ids = []
- for req in self.reqs:
- if req.output_ids and len(req.output_ids) > 0:
- # resumed retracted req
- self.output_ids.append(req.output_ids[-1])
- else:
- assert req.transferred_output_id is not None
- req.output_ids.append(req.transferred_output_id)
- self.output_ids.append(req.transferred_output_id)
- self.tree_cache.cache_unfinished_req(req)
- self.output_ids = torch.tensor(self.output_ids, device=self.device)
-
-
  class SchedulerDisaggregationDecodeMixin:

  def _prepare_idle_batch_and_run(self, batch, delay_process=False):
@@ -483,7 +465,9 @@ class SchedulerDisaggregationDecodeMixin:
  # Generate fake extend output.
  if batch.forward_mode.is_extend():
  # Note: Logprobs should be handled on the prefill engine.
- self.stream_output(batch.reqs, False)
+ self.stream_output(
+ batch.reqs, any(req.return_logprob for req in batch.reqs)
+ )
  if prepare_dp_attn_flag:
  self._prepare_idle_batch_and_run(None)
  else:
@@ -509,7 +493,7 @@ class SchedulerDisaggregationDecodeMixin:
  def event_loop_overlap_disagg_decode(self: Scheduler):
  result_queue = deque()
  self.last_batch: Optional[ScheduleBatch] = None
- self.last_batch_in_queue = False # last batch is modifed in-place, so we need another variable to track if it's extend
+ self.last_batch_in_queue = False # last batch is modified in-place, so we need another variable to track if it's extend

  while True:
  recv_reqs = self.recv_requests()
@@ -529,7 +513,9 @@ class SchedulerDisaggregationDecodeMixin:
  # Generate fake extend output.
  if batch.forward_mode.is_extend():
  # Note: Logprobs should be handled on the prefill engine.
- self.stream_output(batch.reqs, False)
+ self.stream_output(
+ batch.reqs, any(req.return_logprob for req in batch.reqs)
+ )
  if prepare_dp_attn_flag:
  batch_, result = self._prepare_idle_batch_and_run(
  None, delay_process=True
@@ -542,7 +528,18 @@ class SchedulerDisaggregationDecodeMixin:
  self.prepare_dp_attn_batch(batch)
  result = self.run_batch(batch)
  result_queue.append((batch.copy(), result))
+
+ if (self.last_batch is None) or (not self.last_batch_in_queue):
+ # Create a dummy first batch to start the pipeline for overlap schedule.
+ # It is now used for triggering the sampling_info_done event.
+ tmp_batch = ScheduleBatch(
+ reqs=None,
+ forward_mode=ForwardMode.DUMMY_FIRST,
+ next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+ )
+ self.set_next_batch_sampling_info_done(tmp_batch)
  last_batch_in_queue = True
+
  elif prepare_dp_attn_flag:
  batch, result = self._prepare_idle_batch_and_run(
  None, delay_process=True
@@ -554,6 +551,9 @@ class SchedulerDisaggregationDecodeMixin:
  # Process the results of the previous batch but skip if the last batch is extend
  if self.last_batch and self.last_batch_in_queue:
  tmp_batch, tmp_result = result_queue.popleft()
+ tmp_batch.next_batch_sampling_info = (
+ self.tp_worker.cur_sampling_info if batch else None
+ )
  self.process_batch_result(tmp_batch, tmp_result)

  if batch is None and (
@@ -602,6 +602,9 @@ class SchedulerDisaggregationDecodeMixin:

  def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
  """Create a schedulebatch for fake completed prefill"""
+ if self.grammar_queue:
+ self.move_ready_grammar_requests()
+
  if len(self.waiting_queue) == 0:
  return None

@@ -627,8 +630,6 @@ class SchedulerDisaggregationDecodeMixin:
  self.waiting_queue = waiting_queue
  if len(can_run_list) == 0:
  return None
- # local import to avoid circular import
- from sglang.srt.managers.schedule_batch import ScheduleBatch

  # construct a schedule batch with those requests and mark as decode
  new_batch = ScheduleBatch.init_new(
@@ -650,15 +651,8 @@ class SchedulerDisaggregationDecodeMixin:

  def process_decode_queue(self: Scheduler):
  req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
-
- def _num_pre_alloc(req):
- return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
-
- self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
  self.disagg_decode_transfer_queue.extend(req_conns)
  alloc_reqs = (
  self.disagg_decode_transfer_queue.pop_transferred()
  ) # the requests which kv has arrived
- self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
-
- self.waiting_queue.extend([req.req for req in alloc_reqs])
+ self.waiting_queue.extend(alloc_reqs)

sglang/srt/disaggregation/decode_schedule_batch_mixin.py (new file)
@@ -0,0 +1,142 @@
+ from __future__ import annotations
+
+ import logging
+ from typing import TYPE_CHECKING
+
+ import torch
+
+ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
+ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+
+ logger = logging.getLogger(__name__)
+
+ if TYPE_CHECKING:
+ from sglang.srt.configs.model_config import ModelConfig
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
+ from sglang.srt.server_args import ServerArgs
+
+
+ class ScheduleBatchDisaggregationDecodeMixin:
+
+ def prepare_for_prebuilt_extend(self: ScheduleBatch):
+ """
+ Prepare a prebuilt extend by populate metadata
+ Adapted from .prepare_for_extend().
+ """
+
+ self.forward_mode = ForwardMode.EXTEND
+ reqs = self.reqs
+ input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
+ extend_num_tokens = sum(len(ids) for ids in input_ids)
+ seq_lens = []
+ pre_lens = []
+ req_pool_indices = []
+
+ # Pre-calculate total size
+ total_size = sum(req.extend_input_len for req in reqs)
+ out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)
+
+ # Fill the tensor in one pass
+ offset = 0
+ for i, req in enumerate(reqs):
+ req_pool_indices.append(req.req_pool_idx)
+
+ chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+ : req.extend_input_len
+ ]
+ assert (
+ offset + req.extend_input_len <= total_size
+ ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
+ out_cache_loc[offset : offset + req.extend_input_len] = chunk
+ offset += req.extend_input_len
+
+ pre_len = len(req.prefix_indices)
+ seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
+ seq_lens.append(seq_len)
+ if len(req.output_ids) == 0:
+ assert (
+ seq_len - pre_len == req.extend_input_len
+ ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
+
+ req.cached_tokens += pre_len - req.already_computed
+ req.already_computed = seq_len
+ req.is_retracted = False
+ pre_lens.append(pre_len)
+ req.extend_logprob_start_len = 0
+
+ extend_input_logprob_token_ids = None
+
+ # Set fields
+ self.input_ids = torch.tensor(
+ sum(input_ids, []), dtype=torch.int32, device=self.device
+ )
+ self.req_pool_indices = torch.tensor(
+ req_pool_indices, dtype=torch.int64, device=self.device
+ )
+ self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
+ self.out_cache_loc = out_cache_loc
+ self.seq_lens_sum = sum(seq_lens)
+
+ if self.return_logprob:
+ self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+ self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
+
+ self.extend_num_tokens = extend_num_tokens
+ self.prefix_lens = [len(r.prefix_indices) for r in reqs]
+ self.extend_lens = [r.extend_input_len for r in reqs]
+ self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
+ self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
+
+ # Build sampling info
+ self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+ self,
+ self.model_config.vocab_size,
+ )
+
+ def process_prebuilt_extend(
+ self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
+ ):
+ """Assign the buffered last input id to schedule batch"""
+ self.output_ids = []
+ for req in self.reqs:
+ self.output_ids.append(req.output_ids[-1])
+ self.tree_cache.cache_unfinished_req(req)
+ if req.grammar is not None:
+ req.grammar.accept_token(req.output_ids[-1])
+ req.grammar.finished = req.finished()
+ self.output_ids = torch.tensor(self.output_ids, device=self.device)
+
+ # Simulate the eagle run. We add mock data to hidden states for the
+ # ease of implementation now meaning the first token will have acc rate
+ # of 0.
+ if not self.spec_algorithm.is_none():
+
+ b = len(self.reqs)
+ topk_p = torch.arange(
+ b * server_args.speculative_eagle_topk,
+ 0,
+ -1,
+ device=self.device,
+ dtype=torch.float32,
+ )
+ topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
+ topk_p /= b * server_args.speculative_eagle_topk
+ topk_index = torch.arange(
+ b * server_args.speculative_eagle_topk, device=self.device
+ )
+ topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
+
+ # local import to avoid circular import
+ from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
+ spec_info = EagleDraftInput(
+ topk_p=topk_p,
+ topk_index=topk_index,
+ hidden_states=torch.ones(
+ (b, model_config.hidden_size), device=self.device
+ ),
+ verified_id=self.output_ids,
+ )
+ spec_info.prepare_for_extend(self)
+ spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+ self.spec_info = spec_info
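
The ScheduleBatchDisaggregationDecodeMixin removed from decode.py above now lives in this new module, which can import SamplingBatchInfo at module level and drops the local ScheduleBatch import decode.py previously needed to avoid a circular dependency. The diff does not show how the mixin is attached to ScheduleBatch; the snippet below is only a sketch of the usual composition pattern, with placeholder names.

# Illustrative sketch, not part of the diff; the real base list of ScheduleBatch
# (sglang/srt/managers/schedule_batch.py) may differ.
from dataclasses import dataclass
from typing import Any, List, Optional


class PrebuiltExtendMixinSketch:
    """Stands in for ScheduleBatchDisaggregationDecodeMixin defined above."""

    def prepare_for_prebuilt_extend(self):
        # Populate extend metadata for a batch whose KV cache already arrived
        # from the prefill engine.
        ...

    def process_prebuilt_extend(self, server_args, model_config):
        # Append the transferred last token id and, when speculative decoding is
        # enabled, seed a mock EagleDraftInput as shown above.
        ...


@dataclass
class ScheduleBatchSketch(PrebuiltExtendMixinSketch):
    # The decode scheduler can now call batch.prepare_for_prebuilt_extend() and
    # batch.process_prebuilt_extend(...) on any batch it builds.
    reqs: Optional[List[Any]] = None
    device: str = "cuda"
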

sglang/srt/disaggregation/fake/conn.py
@@ -33,28 +33,18 @@ class FakeKVSender(BaseKVSender):
  self,
  kv_indices: list[int],
  aux_index: Optional[int] = None,
- dest_ranks: Optional[list[int]] = None,
  ):
  logger.info(
- f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}"
+ f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
  )
  pass

  def send(
  self,
  kv_indices: npt.NDArray[np.int64],
- index_slice: slice,
- is_last: bool,
  ):
- logger.info(
- f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}"
- )
- if is_last:
- self.has_sent = True
- logger.info(f"FakeKVSender send success")
- else:
- self.has_sent = False
- logger.info(f"FakeKVSender send fake transfering")
+ self.has_sent = True
+ logger.info(f"FakeKVSender send with kv_indices: {kv_indices}")

  def failure_exception(self):
  raise Exception("Fake KVSender Exception")