sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/configs/model_config.py
@@ -22,7 +22,11 @@ from typing import List, Optional, Set, Union
 import torch
 from transformers import PretrainedConfig
 
-from sglang.srt.hf_transformers_utils import get_config, get_context_length
+from sglang.srt.hf_transformers_utils import (
+    get_config,
+    get_context_length,
+    get_hf_text_config,
+)
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_bool_env_var, is_hip
@@ -69,6 +73,7 @@ class ModelConfig:
             model_override_args=self.model_override_args,
             **kwargs,
         )
+
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.attention_chunk_size = getattr(
             self.hf_text_config, "attention_chunk_size", None
@@ -93,6 +98,8 @@ class ModelConfig:
         ):
             self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
 
+        if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
+            self.hf_config.architectures[0] = "MiMoMTP"
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
@@ -109,6 +116,10 @@ class ModelConfig:
         self.is_audio_model = enable_multimodal and is_audio_model(
             self.hf_config.architectures
         )
+        self.is_multimodal_chunked_prefill_supported = (
+            enable_multimodal
+            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
+        )
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
@@ -209,7 +220,13 @@ class ModelConfig:
 
         # Cache attributes
         self.hf_eos_token_id = self.get_hf_eos_token_id()
-        self.image_token_id = getattr(self.hf_config, "image_token_id", None)
+
+        config = self.hf_config
+
+        # multimodal
+        self.image_token_id = getattr(config, "image_token_id", None) or getattr(
+            config, "image_token_index", None
+        )
 
     @staticmethod
     def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
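The fallback above handles the fact that Hugging Face configs are inconsistent about the field name: some expose image_token_id, others image_token_index. A minimal illustration of the lookup (attribute values are placeholders, not taken from any real config):

from types import SimpleNamespace

def resolve_image_token_id(config):
    # Same pattern as the diff above: prefer image_token_id, fall back to image_token_index.
    return getattr(config, "image_token_id", None) or getattr(
        config, "image_token_index", None
    )

cfg_with_id = SimpleNamespace(image_token_id=32000)        # placeholder value
cfg_with_index = SimpleNamespace(image_token_index=64000)  # placeholder value

assert resolve_image_token_id(cfg_with_id) == 32000
assert resolve_image_token_id(cfg_with_index) == 64000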
@@ -332,6 +349,7 @@ class ModelConfig:
             "w8a8_int8",
             "w8a8_fp8",
             "moe_wna16",
+            "qoq",
         ]
         compatible_quantization_methods = {
             "modelopt_fp4": ["modelopt"],
@@ -423,31 +441,6 @@ class ModelConfig:
             self.model_path = client.get_local_dir()
 
 
-def get_hf_text_config(config: PretrainedConfig):
-    """Get the "sub" config relevant to llm for multi modal models.
-    No op for pure text models.
-    """
-    class_name = config.architectures[0]
-    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
-        # We support non-hf version of llava models, so we do not want to
-        # read the wrong values from the unused default text_config.
-        # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
-        # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
-        setattr(config, "torch_dtype", torch.float16)
-        return config
-
-    if hasattr(config, "text_config"):
-        # The code operates under the assumption that text_config should have
-        # `num_attention_heads` (among others). Assert here to fail early
-        # if transformers config doesn't align with this assumption.
-        assert hasattr(config.text_config, "num_attention_heads")
-        return config.text_config
-    if hasattr(config, "language_config"):
-        return config.language_config
-    else:
-        return config
-
-
 # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
@@ -466,6 +459,8 @@ def _get_and_verify_dtype(
     # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
     config_dtype = getattr(config, "torch_dtype", None)
+    if isinstance(config_dtype, str):
+        config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
     if config_dtype is None:
         config_dtype = torch.float32
 
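The new isinstance check covers configs whose torch_dtype is serialized as a plain string (for example "bfloat16" in config.json) instead of a torch.dtype. A standalone sketch of the resulting behavior; apart from the "half" entry visible above, the contents of the mapping table are assumed:

import torch

# Assumed mapping entries; only "half": torch.float16 is visible in this diff.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}

config_dtype = "bfloat16"  # as it can appear after loading config.json
if isinstance(config_dtype, str):
    config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
if config_dtype is None:
    config_dtype = torch.float32

assert config_dtype is torch.bfloat16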
@@ -537,6 +532,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
 
 
 multimodal_model_archs = [
+    "CLIPModel",
     "DeepseekVL2ForCausalLM",
     "Gemma3ForConditionalGeneration",
     "Grok1VForCausalLM",
@@ -549,11 +545,11 @@ multimodal_model_archs = [
     "LlavaVidForCausalLM",
     "MiniCPMO",
     "MiniCPMV",
+    "Mistral3ForConditionalGeneration",
     "MultiModalityCausalLM",
     "MllamaForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
-    "CLIPModel",
     "KimiVLForConditionalGeneration",
     "InternVLChatModel",
 ]
@@ -585,6 +581,21 @@ def is_encoder_decoder_model(model_architectures: List[str]):
     return "MllamaForConditionalGeneration" in model_architectures
 
 
+def is_multimodal_chunked_prefill_supported(model_architectures: List[str]):
+    """Check if chunked prefill is supported for a MultiModal model."""
+    unsupported = [
+        "Grok1VForCausalLM",
+        "Grok1AForCausalLM",
+        "LlavaLlamaForCausalLM",
+        "MllamaForConditionalGeneration",
+        "CLIPModel",
+    ]
+    if any(multi_model_arch in unsupported for multi_model_arch in model_architectures):
+        return False
+    else:
+        return True
+
+
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     if scale <= 1:
         return 1.0
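The new helper is a simple membership test over the model's architecture list; an illustrative check:

# Illustrative usage of is_multimodal_chunked_prefill_supported as defined above.
assert is_multimodal_chunked_prefill_supported(["Qwen2_5_VLForConditionalGeneration"])
assert not is_multimodal_chunked_prefill_supported(["MllamaForConditionalGeneration"])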
sglang/srt/conversation.py
@@ -781,7 +781,7 @@ register_conv_template(
     Conversation(
         name="gemma-it",
         system_message="You are a helpful assistant.",
-        system_template="<start_of_turn>user{system_message}\n\n",
+        system_template="<start_of_turn>user\n{system_message}\n\n",
         roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
         sep="<end_of_turn>\n",
         sep_style=SeparatorStyle.GEMMA3,
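The one-character change restores the newline between the user turn marker and the system message. A minimal rendering check of the template string alone (the surrounding Conversation machinery is not reproduced here):

system_template = "<start_of_turn>user\n{system_message}\n\n"
rendered = system_template.format(system_message="You are a helpful assistant.")
# rendered is:
# <start_of_turn>user
# You are a helpful assistant.
# (followed by a blank line)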
sglang/srt/disaggregation/decode.py
@@ -24,6 +24,7 @@ import logging
 import os
 from collections import deque
 from dataclasses import dataclass
+from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import numpy as np
@@ -35,25 +36,25 @@ from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
     FakeBootstrapHost,
     KVClassType,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
     is_mla_backend,
     kv_to_page_indices,
     poll_and_all_reduce,
+    prepare_abort,
 )
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from sglang.srt.configs.model_config import ModelConfig
-    from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+    from sglang.srt.managers.schedule_batch import Req
     from sglang.srt.managers.scheduler import Scheduler
-    from sglang.srt.server_args import ServerArgs
 
 
 @dataclass
@@ -73,9 +74,9 @@ class DecodePreallocQueue:
         self,
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: List[torch.Tensor],
-        aux_dtype: torch.dtype,
+        metadata_buffers: MetadataBuffers,
         scheduler: Scheduler,
         transfer_queue: DecodeTransferQueue,
         tree_cache: BasePrefixCache,
@@ -88,8 +89,8 @@ class DecodePreallocQueue:
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
+        self.draft_token_to_kv_pool = draft_token_to_kv_pool
         self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
-        self.aux_dtype = aux_dtype
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
         self.scheduler = scheduler
@@ -116,19 +117,21 @@ class DecodePreallocQueue:
             self.token_to_kv_pool.get_contiguous_buf_infos()
         )
 
+        if self.draft_token_to_kv_pool is not None:
+            draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
+                self.draft_token_to_kv_pool.get_contiguous_buf_infos()
+            )
+            kv_data_ptrs += draft_kv_data_ptrs
+            kv_data_lens += draft_kv_data_lens
+            kv_item_lens += draft_kv_item_lens
+
         kv_args.kv_data_ptrs = kv_data_ptrs
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens
 
-        kv_args.aux_data_ptrs = [
-            output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers
-        ]
-        kv_args.aux_data_lens = [
-            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_item_lens = [
-            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
-        ]
+        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
+            self.metadata_buffers.get_buf_infos()
+        )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
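The hand-built aux_data_ptrs / aux_data_lens / aux_item_lens lists are replaced by a single MetadataBuffers helper, added in sglang/srt/disaggregation/utils.py but not shown in this excerpt. As a rough sketch of the shape such a helper needs, based only on the removed list comprehensions above, a minimal stand-in could look like the following; buffer names, dtypes, and widths are assumptions, not the real implementation:

import torch
from typing import List, Tuple

class TinyMetadataBuffers:
    """Hypothetical stand-in: a set of pre-allocated per-slot metadata tensors."""

    def __init__(self, size: int, top_logprobs_width: int = 128):
        # One row per metadata slot; shapes and dtypes are illustrative guesses.
        self.output_ids = torch.zeros((size, 16), dtype=torch.int32)
        self.output_token_logprobs_val = torch.zeros((size, 16), dtype=torch.float32)
        self.output_top_logprobs_val = torch.zeros(
            (size, top_logprobs_width), dtype=torch.float32
        )
        self._buffers = [
            self.output_ids,
            self.output_token_logprobs_val,
            self.output_top_logprobs_val,
        ]

    def get_buf_infos(self) -> Tuple[List[int], List[int], List[int]]:
        # Mirrors the removed list comprehensions: raw data pointers, total byte
        # sizes, and per-row byte sizes, one entry per buffer.
        ptrs = [buf.data_ptr() for buf in self._buffers]
        lens = [buf.nbytes for buf in self._buffers]
        item_lens = [buf[0].nbytes for buf in self._buffers]
        return ptrs, lens, item_lens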
@@ -178,7 +181,17 @@ class DecodePreallocQueue:
             elif poll == KVPoll.WaitingForInput:
                 decode_req.waiting_for_input = True
             elif poll == KVPoll.Failed:
-                raise Exception("Handshake failed")
+                error_message = f"Decode handshake failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+                try:
+                    decode_req.kv_receiver.failure_exception()
+                except Exception as e:
+                    error_message += f" with exception {e}"
+                logger.error(error_message)
+                prepare_abort(
+                    decode_req.req,
+                    error_message,
+                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+                )
 
     def pop_preallocated(self) -> List[DecodeRequest]:
         """Pop the preallocated requests from the pending queue (FIFO)."""
@@ -188,7 +201,18 @@ class DecodePreallocQueue:
         indices_to_remove = set()
         allocatable_tokens = self._allocatable_tokens()
 
+        # First, remove all failed requests from the queue
         for i, decode_req in enumerate(self.queue):
+            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
+                self.scheduler.stream_output(
+                    [decode_req.req], decode_req.req.return_logprob
+                )
+                indices_to_remove.add(i)
+
+        for i, decode_req in enumerate(self.queue):
+            if i in indices_to_remove:
+                continue
+
             if not decode_req.waiting_for_input:
                 continue
 
@@ -308,18 +332,22 @@ class DecodeTransferQueue:
         self,
         gloo_group: ProcessGroup,
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: torch.Tensor,
+        metadata_buffers: MetadataBuffers,
+        scheduler: Scheduler,
+        tree_cache: BasePrefixCache,
     ):
         self.queue: List[DecodeRequest] = []
         self.gloo_group = gloo_group
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
         self.metadata_buffers = metadata_buffers
+        self.scheduler = scheduler
+        self.tree_cache = tree_cache
 
-    def add(self, req_conn: DecodeRequest) -> None:
-        self.queue.append(req_conn)
+    def add(self, decode_req: DecodeRequest) -> None:
+        self.queue.append(decode_req)
 
-    def extend(self, req_conns) -> None:
-        self.queue.extend(req_conns)
+    def extend(self, decode_reqs: List[DecodeRequest]) -> None:
+        self.queue.extend(decode_reqs)
 
     def pop_transferred(self) -> List[DecodeRequest]:
         if not self.queue:
@@ -333,18 +361,56 @@ class DecodeTransferQueue:
         indices_to_remove = set()
         for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
             if poll == KVPoll.Failed:
-                raise Exception("Transfer failed")
+                error_message = f"Decode transfer failed for request {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+                try:
+                    decode_req.kv_receiver.failure_exception()
+                except Exception as e:
+                    error_message += f" with exception {e}"
+                logger.error(error_message)
+                prepare_abort(
+                    decode_req.req,
+                    error_message,
+                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+                )
+                self.scheduler.stream_output(
+                    [decode_req.req], decode_req.req.return_logprob
+                )
+                # unlock the kv cache or it will have memory leak
+                self.tree_cache.cache_finished_req(decode_req.req)
+                indices_to_remove.add(i)
+                continue
             elif poll == KVPoll.Success:
-                # pop and push it to waiting queue
+
                 idx = decode_req.metadata_buffer_index
-                assert len(decode_req.req.output_ids) == 0
-                output_id_buffer = self.metadata_buffers[0]
-                # the last dimension is padded by the same values.
-                output_id = output_id_buffer[idx][0].item()
-                assert len(decode_req.req.output_ids) == 0
-                assert decode_req.req.transferred_output_id is None
-                decode_req.req.transferred_output_id = output_id
-                transferred_reqs.append(decode_req)
+                (
+                    output_id,
+                    output_token_logprobs_val,
+                    output_token_logprobs_idx,
+                    output_top_logprobs_val,
+                    output_top_logprobs_idx,
+                ) = self.metadata_buffers.get_buf(idx)
+
+                decode_req.req.output_ids.append(output_id[0].item())
+
+                if decode_req.req.return_logprob:
+                    decode_req.req.output_token_logprobs_val.append(
+                        output_token_logprobs_val[0].item()
+                    )
+                    decode_req.req.output_token_logprobs_idx.append(
+                        output_token_logprobs_idx[0].item()
+                    )
+                    decode_req.req.output_top_logprobs_val.append(
+                        output_top_logprobs_val[
+                            : decode_req.req.top_logprobs_num
+                        ].tolist()
+                    )
+                    decode_req.req.output_top_logprobs_idx.append(
+                        output_top_logprobs_idx[
+                            : decode_req.req.top_logprobs_num
+                        ].tolist()
+                    )
+
+                transferred_reqs.append(decode_req.req)
                 indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,
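The logprob rows read back from the metadata buffers are fixed-width and padded (see the comment in the removed code), which is why each request keeps only the first top_logprobs_num entries. A toy illustration of that slicing (width and values are made up):

import torch

top_logprobs_num = 3  # per-request setting; value made up for the example
padded_row = torch.tensor([-0.1, -1.2, -2.3, 0.0, 0.0])  # fixed-width, padded row
kept = padded_row[:top_logprobs_num].tolist()
# kept == [-0.1, -1.2, -2.3] up to float32 rounding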
@@ -367,95 +433,6 @@ class DecodeTransferQueue:
         return transferred_reqs
 
 
-class ScheduleBatchDisaggregationDecodeMixin:
-
-    def prepare_for_prebuilt_extend(self: ScheduleBatch):
-        """
-        Prepare a prebuilt extend by populate metadata
-        Adapted from .prepare_for_extend().
-        """
-
-        self.forward_mode = ForwardMode.EXTEND
-        reqs = self.reqs
-        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
-        extend_num_tokens = sum(len(ids) for ids in input_ids)
-        seq_lens = []
-        pre_lens = []
-        req_pool_indices = []
-
-        # Pre-calculate total size
-        total_size = sum(req.extend_input_len for req in reqs)
-        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)
-
-        # Fill the tensor in one pass
-        offset = 0
-        for i, req in enumerate(reqs):
-            req_pool_indices.append(req.req_pool_idx)
-
-            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                : req.extend_input_len
-            ]
-            assert (
-                offset + req.extend_input_len <= total_size
-            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
-            out_cache_loc[offset : offset + req.extend_input_len] = chunk
-            offset += req.extend_input_len
-
-            pre_len = len(req.prefix_indices)
-            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
-            seq_lens.append(seq_len)
-            if len(req.output_ids) == 0:
-                assert (
-                    seq_len - pre_len == req.extend_input_len
-                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
-
-            req.cached_tokens += pre_len - req.already_computed
-            req.already_computed = seq_len
-            req.is_retracted = False
-            pre_lens.append(pre_len)
-            req.extend_logprob_start_len = 0
-
-        extend_input_logprob_token_ids = None
-
-        # Set fields
-        self.input_ids = torch.tensor(
-            sum(input_ids, []), dtype=torch.int32, device=self.device
-        )
-        self.req_pool_indices = torch.tensor(
-            req_pool_indices, dtype=torch.int64, device=self.device
-        )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
-        self.out_cache_loc = out_cache_loc
-        self.seq_lens_sum = sum(seq_lens)
-        self.extend_num_tokens = extend_num_tokens
-        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
-        self.extend_lens = [r.extend_input_len for r in reqs]
-        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
-        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
-
-        # Build sampling info
-        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
-            self,
-            self.model_config.vocab_size,
-        )
-
-    def process_prebuilt_extend(
-        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
-    ):
-        """Assign the buffered last input id to schedule batch"""
-        self.output_ids = []
-        for req in self.reqs:
-            if req.output_ids and len(req.output_ids) > 0:
-                # resumed retracted req
-                self.output_ids.append(req.output_ids[-1])
-            else:
-                assert req.transferred_output_id is not None
-                req.output_ids.append(req.transferred_output_id)
-                self.output_ids.append(req.transferred_output_id)
-            self.tree_cache.cache_unfinished_req(req)
-        self.output_ids = torch.tensor(self.output_ids, device=self.device)
-
-
 class SchedulerDisaggregationDecodeMixin:
 
     def _prepare_idle_batch_and_run(self, batch, delay_process=False):
@@ -488,7 +465,9 @@ class SchedulerDisaggregationDecodeMixin:
             # Generate fake extend output.
             if batch.forward_mode.is_extend():
                 # Note: Logprobs should be handled on the prefill engine.
-                self.stream_output(batch.reqs, False)
+                self.stream_output(
+                    batch.reqs, any(req.return_logprob for req in batch.reqs)
+                )
             if prepare_dp_attn_flag:
                 self._prepare_idle_batch_and_run(None)
         else:
@@ -534,7 +513,9 @@ class SchedulerDisaggregationDecodeMixin:
             # Generate fake extend output.
             if batch.forward_mode.is_extend():
                 # Note: Logprobs should be handled on the prefill engine.
-                self.stream_output(batch.reqs, False)
+                self.stream_output(
+                    batch.reqs, any(req.return_logprob for req in batch.reqs)
+                )
             if prepare_dp_attn_flag:
                 batch_, result = self._prepare_idle_batch_and_run(
                     None, delay_process=True
@@ -547,7 +528,18 @@ class SchedulerDisaggregationDecodeMixin:
                 self.prepare_dp_attn_batch(batch)
             result = self.run_batch(batch)
             result_queue.append((batch.copy(), result))
+
+            if (self.last_batch is None) or (not self.last_batch_in_queue):
+                # Create a dummy first batch to start the pipeline for overlap schedule.
+                # It is now used for triggering the sampling_info_done event.
+                tmp_batch = ScheduleBatch(
+                    reqs=None,
+                    forward_mode=ForwardMode.DUMMY_FIRST,
+                    next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                )
+                self.set_next_batch_sampling_info_done(tmp_batch)
             last_batch_in_queue = True
+
         elif prepare_dp_attn_flag:
             batch, result = self._prepare_idle_batch_and_run(
                 None, delay_process=True
@@ -559,6 +551,9 @@ class SchedulerDisaggregationDecodeMixin:
         # Process the results of the previous batch but skip if the last batch is extend
         if self.last_batch and self.last_batch_in_queue:
             tmp_batch, tmp_result = result_queue.popleft()
+            tmp_batch.next_batch_sampling_info = (
+                self.tp_worker.cur_sampling_info if batch else None
+            )
             self.process_batch_result(tmp_batch, tmp_result)
 
         if batch is None and (
@@ -607,6 +602,9 @@ class SchedulerDisaggregationDecodeMixin:
 
     def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
         """Create a schedulebatch for fake completed prefill"""
+        if self.grammar_queue:
+            self.move_ready_grammar_requests()
+
         if len(self.waiting_queue) == 0:
             return None
 
@@ -632,8 +630,6 @@ class SchedulerDisaggregationDecodeMixin:
         self.waiting_queue = waiting_queue
         if len(can_run_list) == 0:
             return None
-        # local import to avoid circular import
-        from sglang.srt.managers.schedule_batch import ScheduleBatch
 
         # construct a schedule batch with those requests and mark as decode
         new_batch = ScheduleBatch.init_new(
@@ -655,15 +651,8 @@ class SchedulerDisaggregationDecodeMixin:
 
     def process_decode_queue(self: Scheduler):
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
-
-        def _num_pre_alloc(req):
-            return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
-
-        self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
-        self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
-
-        self.waiting_queue.extend([req.req for req in alloc_reqs])
+        self.waiting_queue.extend(alloc_reqs)