sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff shows the contents of two publicly released package versions as published to a supported registry, and is provided for informational purposes only.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py

@@ -43,10 +43,10 @@ from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.disaggregation.utils import DisaggregationMode
  from sglang.srt.lora.lora_registry import LoRARegistry
  from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
+ from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor
  from sglang.srt.managers.disagg_service import start_disagg_service
  from sglang.srt.managers.io_struct import (
      AbortReq,
-     BaseReq,
      BatchEmbeddingOutput,
      BatchMultimodalOutput,
      BatchStrOutput,
@@ -69,6 +69,7 @@ from sglang.srt.managers.io_struct import (
  )
  from sglang.srt.managers.mm_utils import TensorTransportMode
  from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
+ from sglang.srt.managers.schedule_batch import RequestStage
  from sglang.srt.managers.scheduler import is_health_check_generate_req
  from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
  from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
@@ -80,6 +81,7 @@ from sglang.srt.tracing.trace import (
      trace_get_proc_propagate_context,
      trace_req_finish,
      trace_req_start,
+     trace_set_remote_propagate_context,
      trace_slice_end,
      trace_slice_start,
  )
@@ -171,7 +173,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
          self.context_len = self.model_config.context_len
          self.image_token_id = self.model_config.image_token_id
          self.max_req_input_len = None  # Will be set later in engine.py
-
          speculative_algorithm = SpeculativeAlgorithm.from_string(
              server_args.speculative_algorithm
          )
@@ -180,9 +181,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
              if speculative_algorithm.is_none()
              else server_args.speculative_num_draft_tokens
          )
-         # Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
-         self.multi_item_delimiter_text = None

+         # Initialize tokenizer and processor
          if self.model_config.is_multimodal:
              import_processors("sglang.srt.multimodal.processors")
              try:
@@ -216,6 +216,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
              self.mm_processor = get_mm_processor(
                  self.model_config.hf_config, server_args, _processor, transport_mode
              )
+             self.mm_data_processor = AsyncMMDataProcessor(
+                 self.mm_processor,
+                 max_concurrent_calls=self.server_args.mm_max_concurrent_calls,
+                 timeout_s=self.server_args.mm_per_request_timeout,
+             )

          if server_args.skip_tokenizer_init:
              self.tokenizer = self.processor = None
@@ -237,6 +242,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                  revision=server_args.revision,
              )
              self._initialize_multi_item_delimiter_text()
+
          # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
          if (
              server_args.enable_dynamic_batch_tokenizer
@@ -255,24 +261,20 @@
          self.recv_from_detokenizer = get_zmq_socket(
              context, zmq.PULL, port_args.tokenizer_ipc_name, True
          )
-         if self.server_args.tokenizer_worker_num > 1:
+         if self.server_args.tokenizer_worker_num == 1:
+             self.send_to_scheduler = get_zmq_socket(
+                 context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+             )
+         else:
+             from sglang.srt.managers.multi_tokenizer_mixin import SenderWrapper
+
              # Use tokenizer_worker_ipc_name in multi-tokenizer mode
              send_to_scheduler = get_zmq_socket(
                  context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
              )

-             class SenderWrapper:
-                 def send_pyobj(self, obj):
-                     if isinstance(obj, BaseReq):
-                         obj.http_worker_ipc = port_args.tokenizer_ipc_name
-                     send_to_scheduler.send_pyobj(obj)
-
              # Make sure that each request carries the tokenizer_ipc_name for response routing
-             self.send_to_scheduler = SenderWrapper()
-         else:
-             self.send_to_scheduler = get_zmq_socket(
-                 context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-             )
+             self.send_to_scheduler = SenderWrapper(port_args, send_to_scheduler)

          # Request states
          self._chosen_loop = None
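Note on the refactor above: the inline SenderWrapper class moves into sglang/srt/managers/multi_tokenizer_mixin.py (+22 lines in this diff). A minimal sketch of what the extracted class plausibly looks like, assuming the (port_args, send_to_scheduler) constructor implied by the new call site; this is a reconstruction, not the verbatim source:

from sglang.srt.managers.io_struct import BaseReq


class SenderWrapper:
    """Stamp each outgoing request with this worker's IPC name (sketch)."""

    def __init__(self, port_args, send_to_scheduler):
        self.port_args = port_args
        self.send_to_scheduler = send_to_scheduler

    def send_pyobj(self, obj):
        # Carry tokenizer_ipc_name on each request so the scheduler can
        # route the response back to the right tokenizer worker, exactly
        # as the deleted inline class did.
        if isinstance(obj, BaseReq):
            obj.http_worker_ipc = self.port_args.tokenizer_ipc_name
        self.send_to_scheduler.send_pyobj(obj)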
@@ -320,6 +322,7 @@
          # LoRA updates and inference to overlap.
          self.lora_update_lock = asyncio.Lock()

+         # Disaggregation
          self.disaggregation_mode = DisaggregationMode(
              self.server_args.disaggregation_mode
          )
@@ -388,10 +391,11 @@
          self.auto_create_handle_loop()
          obj.normalize_batch_and_arguments()

-         if self.server_args.tokenizer_worker_num > 1:
-             from sglang.srt.managers.multi_tokenizer_mixin import TokenizerWorker
+         if request:
+             if "trace_context" in request.headers:
+                 trace_set_remote_propagate_context(request.headers["trace_context"])

-             assert isinstance(self, TokenizerWorker)
+         if self.server_args.tokenizer_worker_num > 1:
              self._attach_multi_http_worker_info(obj)

          if self.enable_trace:
@@ -600,10 +604,10 @@
              obj.image_data = [obj.image_data]
          if obj.audio_data is not None and not isinstance(obj.audio_data, list):
              obj.audio_data = [obj.audio_data]
-         mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
+         mm_inputs: Dict = await self.mm_data_processor.process(
              image_data=obj.image_data,
              audio_data=obj.audio_data,
-             input_text=input_text or input_ids,
+             input_text_or_ids=(input_text or input_ids),
              request_obj=obj,
              max_req_input_len=self.max_req_input_len,
          )
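Context for the call-site change above: multimodal preprocessing now goes through the new AsyncMMDataProcessor (sglang/srt/managers/async_mm_data_processor.py, +122 lines), constructed earlier with max_concurrent_calls and timeout_s. A hedged sketch of the wrapper's core idea; the semaphore-plus-timeout mechanics are assumed, not copied from the new file:

import asyncio


class AsyncMMDataProcessor:
    """Bound concurrency and enforce a per-request timeout (sketch)."""

    def __init__(self, mm_processor, max_concurrent_calls: int, timeout_s: float):
        self.mm_processor = mm_processor
        self.semaphore = asyncio.Semaphore(max_concurrent_calls)
        self.timeout_s = timeout_s

    async def process(self, *, input_text_or_ids, **kwargs):
        # At most max_concurrent_calls preprocessing jobs run at once;
        # each is cancelled if it exceeds timeout_s.
        async with self.semaphore:
            return await asyncio.wait_for(
                self.mm_processor.process_mm_data_async(
                    input_text=input_text_or_ids, **kwargs
                ),
                timeout=self.timeout_s,
            )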
@@ -613,7 +617,7 @@
              mm_inputs = None

          self._validate_one_request(obj, input_ids)
-         trace_slice_end("tokenize", obj.rid)
+         trace_slice_end(RequestStage.TOKENIZE, obj.rid)
          return self._create_tokenized_object(
              obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
          )
@@ -674,6 +678,10 @@
              )
              raise ValueError(error_msg)

+         # Matryoshka embeddings validations
+         if isinstance(obj, EmbeddingReqInput):
+             self._validate_for_matryoshka_dim(obj)
+
          if isinstance(obj, GenerateReqInput):
              if (
                  obj.return_hidden_states
@@ -692,6 +700,34 @@
                  "Please set `--enable-custom-logit-processor` to enable this feature."
              )

+     def _validate_for_matryoshka_dim(self, obj: EmbeddingReqInput) -> None:
+         """Validate the request for Matryoshka dim if it has the field set."""
+         if obj.dimensions is None:
+             return
+
+         if not self.model_config.is_matryoshka:
+             raise ValueError(
+                 f"Model '{self.model_config.model_path}' does not support matryoshka representation, "
+                 f"changing output dimensions will lead to poor results."
+             )
+
+         if obj.dimensions < 1:
+             raise ValueError("Requested dimensions must be greater than 0")
+
+         if (
+             self.model_config.matryoshka_dimensions
+             and obj.dimensions not in self.model_config.matryoshka_dimensions
+         ):
+             raise ValueError(
+                 f"Model '{self.model_config.model_path}' only supports {self.model_config.matryoshka_dimensions} matryoshka dimensions, "
+                 f"using other output dimensions will lead to poor results."
+             )
+
+         if obj.dimensions > self.model_config.hidden_size:
+             raise ValueError(
+                 f"Provided dimensions are greater than max embedding dimension: {self.model_config.hidden_size}"
+             )
+
      def _validate_input_ids_in_vocab(
          self, input_ids: List[int], vocab_size: int
      ) -> None:
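Background for the validator above: Matryoshka (MRL) embedding models are trained so that a prefix of the output vector is itself a usable embedding, which is why `dimensions` must be positive, no larger than `hidden_size`, and (when the config enumerates them) one of the trained `matryoshka_dimensions`. Downstream, honoring `dimensions` amounts to truncating and re-normalizing; an illustrative sketch, not sglang's implementation:

import torch


def truncate_matryoshka(embedding: torch.Tensor, dimensions: int) -> torch.Tensor:
    # Keep the first `dimensions` components, then L2-normalize so the
    # truncated prefix is still a unit-length embedding.
    truncated = embedding[..., :dimensions]
    return truncated / truncated.norm(dim=-1, keepdim=True)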
@@ -760,6 +796,7 @@
              sampling_params,
              rid=obj.rid,
              priority=obj.priority,
+             dimensions=obj.dimensions,
              http_worker_ipc=obj.http_worker_ipc,
          )

@@ -806,7 +843,7 @@
                      req, req.text, input_ids_list[i], None, None, token_type_ids
                  )
              )
-             trace_slice_end("tokenize", req.rid)
+             trace_slice_end(RequestStage.TOKENIZE, req.rid)
          logger.debug(f"Completed batch processing for {batch_size} requests")
          return tokenized_objs

@@ -858,12 +895,14 @@
          tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
          created_time: Optional[float] = None,
      ):
-         trace_slice_start("dispatch", obj.rid)
+         trace_slice_start(RequestStage.TOKENIZER_DISPATCH, obj.rid)
          tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
          self.send_to_scheduler.send_pyobj(tokenized_obj)
          state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
          self.rid_to_state[obj.rid] = state
-         trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
+         trace_slice_end(
+             RequestStage.TOKENIZER_DISPATCH, obj.rid, thread_finish_flag=True
+         )
          return state

      def _send_batch_request(
@@ -1365,6 +1404,7 @@
                  "finish_reason": recv_obj.finished_reasons[i],
                  "prompt_tokens": recv_obj.prompt_tokens[i],
                  "weight_version": self.server_args.weight_version,
+                 "total_retractions": recv_obj.retraction_counts[i],
              }

              if getattr(state.obj, "return_logprob", False):
@@ -1453,6 +1493,51 @@
          if self.crash_dump_folder and state.finished and state.obj.log_metrics:
              self.record_request_for_crash_dump(state, out_dict)

+     def add_logprob_to_meta_info(
+         self,
+         meta_info: dict,
+         state: ReqState,
+         top_logprobs_num: int,
+         token_ids_logprob: List[int],
+         return_text_in_logprobs: bool,
+     ):
+         meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
+             state.input_token_logprobs_val,
+             state.input_token_logprobs_idx,
+             return_text_in_logprobs,
+         )
+         meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
+             state.output_token_logprobs_val,
+             state.output_token_logprobs_idx,
+             return_text_in_logprobs,
+         )
+
+         if top_logprobs_num > 0:
+             meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                 state.input_top_logprobs_val,
+                 state.input_top_logprobs_idx,
+                 return_text_in_logprobs,
+             )
+             meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                 state.output_top_logprobs_val,
+                 state.output_top_logprobs_idx,
+                 return_text_in_logprobs,
+             )
+
+         if token_ids_logprob is not None:
+             meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
+                 state.input_token_ids_logprobs_val,
+                 state.input_token_ids_logprobs_idx,
+                 return_text_in_logprobs,
+             )
+             meta_info["output_token_ids_logprobs"] = (
+                 self.detokenize_top_logprobs_tokens(
+                     state.output_token_ids_logprobs_val,
+                     state.output_token_ids_logprobs_idx,
+                     return_text_in_logprobs,
+                 )
+             )
+
      def convert_logprob_style(
          self,
          meta_info: dict,
@@ -1479,16 +1564,6 @@
              state.output_token_logprobs_idx.extend(
                  recv_obj.output_token_logprobs_idx[recv_obj_index]
              )
-             meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
-                 state.input_token_logprobs_val,
-                 state.input_token_logprobs_idx,
-                 return_text_in_logprobs,
-             )
-             meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
-                 state.output_token_logprobs_val,
-                 state.output_token_logprobs_idx,
-                 return_text_in_logprobs,
-             )

          if top_logprobs_num > 0:
              if len(recv_obj.input_top_logprobs_val) > 0:
@@ -1504,16 +1579,6 @@
              state.output_top_logprobs_idx.extend(
                  recv_obj.output_top_logprobs_idx[recv_obj_index]
              )
-             meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                 state.input_top_logprobs_val,
-                 state.input_top_logprobs_idx,
-                 return_text_in_logprobs,
-             )
-             meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                 state.output_top_logprobs_val,
-                 state.output_top_logprobs_idx,
-                 return_text_in_logprobs,
-             )

          if token_ids_logprob is not None:
              if len(recv_obj.input_token_ids_logprobs_val) > 0:
@@ -1529,18 +1594,14 @@
              state.output_token_ids_logprobs_idx.extend(
                  recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
              )
-             meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
-                 state.input_token_ids_logprobs_val,
-                 state.input_token_ids_logprobs_idx,
-                 return_text_in_logprobs,
-             )
-             meta_info["output_token_ids_logprobs"] = (
-                 self.detokenize_top_logprobs_tokens(
-                     state.output_token_ids_logprobs_val,
-                     state.output_token_ids_logprobs_idx,
-                     return_text_in_logprobs,
-                 )
-             )
+
+         self.add_logprob_to_meta_info(
+             meta_info,
+             state,
+             state.obj.top_logprobs_num,
+             state.obj.token_ids_logprob,
+             return_text_in_logprobs,
+         )

      def detokenize_logprob_tokens(
          self,
@@ -1657,6 +1718,14 @@
                  or state.obj.sampling_params.get("ebnf", None)
                  or state.obj.sampling_params.get("structural_tag", None)
              )
+
+             retraction_count = (
+                 recv_obj.retraction_counts[i]
+                 if getattr(recv_obj, "retraction_counts", None)
+                 and i < len(recv_obj.retraction_counts)
+                 else 0
+             )
+
              self.metrics_collector.observe_one_finished_request(
                  labels,
                  recv_obj.prompt_tokens[i],
@@ -1664,6 +1733,7 @@
                  recv_obj.cached_tokens[i],
                  state.finished_time - state.created_time,
                  has_grammar,
+                 retraction_count,
              )

      def dump_requests(self, state: ReqState, out_dict: dict):
@@ -1716,26 +1786,33 @@
              return
          state = self.rid_to_state[recv_obj.rid]
          state.finished = True
+
+         abort_message = recv_obj.abort_message or "Abort in waiting queue"
+         finish_reason = {
+             "type": "abort",
+             "message": abort_message,
+         }
          if recv_obj.finished_reason:
-             out = {
-                 "meta_info": {
-                     "id": recv_obj.rid,
-                     "finish_reason": recv_obj.finished_reason,
-                 },
-             }
-         else:
-             out = {
-                 "text": "",
-                 "meta_info": {
-                     "id": recv_obj.rid,
-                     "finish_reason": {
-                         "type": "abort",
-                         "message": "Abort before prefill",
-                     },
-                     "prompt_tokens": 0,
-                     "completion_tokens": 0,
-                 },
-             }
+             finish_reason = recv_obj.finished_reason
+         meta_info = {"id": recv_obj.rid, "finish_reason": finish_reason}
+         is_stream = getattr(state.obj, "stream", False)
+         if getattr(state.obj, "return_logprob", False):
+             self.add_logprob_to_meta_info(
+                 meta_info,
+                 state,
+                 state.obj.top_logprobs_num,
+                 state.obj.token_ids_logprob,
+                 state.obj.return_text_in_logprobs
+                 and not self.server_args.skip_tokenizer_init,
+             )
+
+         output_ids = state.output_ids
+         meta_info["completion_tokens"] = len(output_ids)
+         out = {
+             "text": state.text,
+             "output_ids": [output_ids[-1]] if is_stream else output_ids,
+             "meta_info": meta_info,
+         }
          state.out_list.append(out)
          state.event.set()

@@ -2096,7 +2173,12 @@
              bootstrap_room = (
                  obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
              )
-             trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
+             trace_req_start(
+                 obj.rid,
+                 bootstrap_room,
+                 ts=int(created_time * 1e9),
+                 role=self.server_args.disaggregation_mode,
+             )
              trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
          else:
              for i in range(len(obj.rid)):
@@ -2105,7 +2187,12 @@
                      if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
                      else None
                  )
-                 trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
+                 trace_req_start(
+                     obj.rid[i],
+                     bootstrap_room,
+                     ts=int(created_time * 1e9),
+                     role=self.server_args.disaggregation_mode,
+                 )
                  trace_slice_start(
                      "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
                  )
sglang/srt/managers/tp_worker.py

@@ -35,7 +35,7 @@ from sglang.srt.managers.io_struct import (
      UpdateWeightsFromIPCReqInput,
      UpdateWeightsFromTensorReqInput,
  )
- from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
  from sglang.srt.managers.scheduler import GenerationBatchResult
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
@@ -425,3 +425,26 @@ class TpModelWorker(BaseTpWorker):
              pp_hidden_states_proxy_tensors=pp_proxy_tensors,
              can_run_cuda_graph=can_run_cuda_graph,
          )
+
+     def forward_batch_split_prefill(self, batch: ScheduleBatch):
+         if batch.split_index == 0:
+             model_worker_batch = batch.get_model_worker_batch()
+             forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+             batch.split_forward_batch = forward_batch
+             batch.seq_lens_cpu_cache = model_worker_batch.seq_lens_cpu
+         else:
+             model_worker_batch = batch.get_model_worker_batch(batch.seq_lens_cpu_cache)
+
+         logits_output, can_run_cuda_graph = self.model_runner.forward(
+             batch.split_forward_batch, split_forward_count=batch.split_forward_count
+         )
+         if logits_output:
+             next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+         else:
+             next_token_ids = None
+         batch_result = GenerationBatchResult(
+             logits_output=logits_output,
+             can_run_cuda_graph=can_run_cuda_graph,
+         )
+         batch_result.next_token_ids = next_token_ids
+         return batch_result
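The new forward_batch_split_prefill builds the ForwardBatch once (on split_index == 0), caches it on the batch, and reuses it for later splits; logits and sampled tokens only materialize on the split where the model runner returns output. A hypothetical driver loop showing the intended calling convention; num_splits and the loop structure are illustrative, and the real caller is the scheduler/PD-multiplexing code this release also adds:

def run_split_prefill(worker, batch, num_splits: int):
    # Prefill is chunked into num_splits forward passes so other work can
    # interleave between them (e.g., decode under PD multiplexing).
    batch.split_forward_count = num_splits
    result = None
    for split_index in range(num_splits):
        batch.split_index = split_index
        result = worker.forward_batch_split_prefill(batch)
    # Only the final split carries logits_output / next_token_ids.
    return result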
sglang/srt/mem_cache/base_prefix_cache.py

@@ -1,12 +1,31 @@
+ from __future__ import annotations
+
  from abc import ABC, abstractmethod
- from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     NamedTuple,
+     Optional,
+     Protocol,
+     Tuple,
+     runtime_checkable,
+ )

  import torch

+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+
  if TYPE_CHECKING:
      from sglang.srt.managers.schedule_batch import Req
- else:
-     Req = Any  # Placeholder for Req type when not type checking
+
+
+ @runtime_checkable
+ class PrefixCacheTrait(Protocol):
+     req_to_token_pool: ReqToTokenPool
+     token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator
+     page_size: int
+     disable: bool


  class MatchResult(NamedTuple):
@@ -28,7 +47,7 @@ class MatchResult(NamedTuple):
      host_hit_length: int = 0


- class BasePrefixCache(ABC):
+ class BasePrefixCache(ABC, PrefixCacheTrait):
      """Cache can be indexed by either rid or key."""

      @abstractmethod
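The PrefixCacheTrait protocol above replaces the old `Req = Any` fallback: because it is runtime_checkable, isinstance performs a structural check for the declared attributes instead of requiring inheritance. A self-contained illustration of that behavior, using a toy class that is not from the diff:

from typing import Protocol, runtime_checkable


@runtime_checkable
class PrefixCacheTrait(Protocol):
    page_size: int
    disable: bool


class ToyCache:
    def __init__(self):
        self.page_size = 1
        self.disable = False


# isinstance() on a runtime_checkable protocol only checks that the declared
# attributes exist; ToyCache never inherits from PrefixCacheTrait.
assert isinstance(ToyCache(), PrefixCacheTrait)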
sglang/srt/mem_cache/common.py

@@ -89,6 +89,7 @@ def write_cache_indices(
      prefix_pointers = torch.tensor(
          [t.data_ptr() for t in prefix_tensors],
          device=req_to_token_pool.device,
+         dtype=torch.uint64,
      )
      # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
      write_req_to_token_pool_triton[(req_pool_indices_tensor.shape[0],)](
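The added `dtype=torch.uint64` matters because `Tensor.data_ptr()` returns a raw address as a Python int; without an explicit dtype, `torch.tensor` infers `int64`, which overflows for addresses at or above 2**63. A small demonstration, assuming a PyTorch build with the unsigned integer dtypes (available since roughly 2.3):

import torch

buf = torch.empty(16)
# data_ptr() is an unsigned 64-bit address; store it losslessly as uint64
# before handing the pointer table to a Triton kernel.
ptr_tensor = torch.tensor([buf.data_ptr()], dtype=torch.uint64)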
sglang/srt/mem_cache/hicache_storage.py

@@ -19,7 +19,13 @@ def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
          hasher.update(bytes.fromhex(prior_hash))

      for t in token_ids:
-         hasher.update(t.to_bytes(4, byteorder="little", signed=False))
+         if isinstance(t, tuple):
+             # EAGLE bigram mode: hash both elements to uniquely identify the bigram
+             for elem in t:
+                 hasher.update(elem.to_bytes(4, byteorder="little", signed=False))
+         else:
+             # Regular mode: single integer token
+             hasher.update(t.to_bytes(4, byteorder="little", signed=False))

      return hasher.hexdigest()
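For reference, the hash chaining this function supports: each page's hash folds in the previous page's digest via `prior_hash`, and after this change a token may also be an EAGLE bigram tuple. Illustrative calls against the function as shown above; the token values are made up:

# Regular mode: pages of plain token ids, chained through prior_hash.
h1 = get_hash_str([101, 102, 103, 104])
h2 = get_hash_str([105, 106, 107, 108], prior_hash=h1)

# EAGLE bigram mode: each element is a (token, next_token) tuple, and both
# members are folded into the digest.
h3 = get_hash_str([(101, 102), (102, 103)], prior_hash=h2)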